Skip to content

Commit 0dc1a48

Browse files
authored
reafactored functions module in python sdk to subpackage (cocoindex-io#1082)
1 parent d064c52 commit 0dc1a48

File tree

4 files changed

+420
-0
lines changed

4 files changed

+420
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Functions module for cocoindex.
2+
3+
This module provides various function specifications and executors for data processing,
4+
including embedding functions, text processing, and multimodal operations.
5+
"""
6+
7+
# Import all engine builtin function specs
8+
from ._engine_builtin_specs import (
9+
ParseJson,
10+
SplitRecursively,
11+
SplitBySeparators,
12+
EmbedText,
13+
ExtractByLlm,
14+
)
15+
16+
# Import SentenceTransformer embedding functionality
17+
from .sbert import (
18+
SentenceTransformerEmbed,
19+
SentenceTransformerEmbedExecutor,
20+
)
21+
22+
# Import ColPali multimodal embedding functionality
23+
from .colpali import (
24+
ColPaliEmbedImage,
25+
ColPaliEmbedImageExecutor,
26+
ColPaliEmbedQuery,
27+
ColPaliEmbedQueryExecutor,
28+
)
29+
30+
__all__ = [
31+
# Engine builtin specs
32+
"ParseJson",
33+
"SplitRecursively",
34+
"SplitBySeparators",
35+
"EmbedText",
36+
"ExtractByLlm",
37+
# SentenceTransformer
38+
"SentenceTransformerEmbed",
39+
"SentenceTransformerEmbedExecutor",
40+
# ColPali
41+
"ColPaliEmbedImage",
42+
"ColPaliEmbedImageExecutor",
43+
"ColPaliEmbedQuery",
44+
"ColPaliEmbedQueryExecutor",
45+
]
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""All builtin function specs."""
2+
3+
import dataclasses
4+
from typing import Literal
5+
6+
from .. import llm, op
7+
8+
9+
class ParseJson(op.FunctionSpec):
10+
"""Parse a text into a JSON object."""
11+
12+
13+
@dataclasses.dataclass
14+
class CustomLanguageSpec:
15+
"""Custom language specification."""
16+
17+
language_name: str
18+
separators_regex: list[str]
19+
aliases: list[str] = dataclasses.field(default_factory=list)
20+
21+
22+
class SplitRecursively(op.FunctionSpec):
23+
"""Split a document (in string) recursively."""
24+
25+
custom_languages: list[CustomLanguageSpec] = dataclasses.field(default_factory=list)
26+
27+
28+
class SplitBySeparators(op.FunctionSpec):
29+
"""
30+
Split text by specified regex separators only.
31+
Output schema matches SplitRecursively for drop-in compatibility:
32+
KTable rows with fields: location (Range), text (Str), start, end.
33+
Args:
34+
separators_regex: list[str] # e.g., [r"\\n\\n+"]
35+
keep_separator: Literal["NONE", "LEFT", "RIGHT"] = "NONE"
36+
include_empty: bool = False
37+
trim: bool = True
38+
"""
39+
40+
separators_regex: list[str] = dataclasses.field(default_factory=list)
41+
keep_separator: Literal["NONE", "LEFT", "RIGHT"] = "NONE"
42+
include_empty: bool = False
43+
trim: bool = True
44+
45+
46+
class EmbedText(op.FunctionSpec):
47+
"""Embed a text into a vector space."""
48+
49+
api_type: llm.LlmApiType
50+
model: str
51+
address: str | None = None
52+
output_dimension: int | None = None
53+
task_type: str | None = None
54+
api_config: llm.VertexAiConfig | None = None
55+
56+
57+
class ExtractByLlm(op.FunctionSpec):
58+
"""Extract information from a text using a LLM."""
59+
60+
llm_spec: llm.LlmSpec
61+
output_type: type
62+
instruction: str | None = None
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
"""ColPali image and query embedding functions for multimodal document retrieval."""
2+
3+
import functools
4+
from dataclasses import dataclass
5+
from typing import Any, Optional, TYPE_CHECKING, Literal
6+
import numpy as np
7+
8+
from .. import op
9+
from ..typing import Vector
10+
11+
if TYPE_CHECKING:
12+
import torch
13+
14+
15+
@dataclass
16+
class ColPaliModelInfo:
17+
"""Shared model information for ColPali embedding functions."""
18+
19+
model: Any
20+
processor: Any
21+
device: Any
22+
dimension: int
23+
24+
25+
@functools.lru_cache(maxsize=None)
26+
def _get_colpali_model_and_processor(model_name: str) -> ColPaliModelInfo:
27+
"""Load and cache ColPali model and processor with shared device setup."""
28+
try:
29+
from colpali_engine import ( # type: ignore[import-untyped]
30+
ColPali,
31+
ColPaliProcessor,
32+
ColQwen2,
33+
ColQwen2Processor,
34+
ColSmol,
35+
ColSmolProcessor,
36+
)
37+
import torch
38+
except ImportError as e:
39+
raise ImportError(
40+
"ColPali support requires the optional 'colpali' dependency. "
41+
"Install it with: pip install 'cocoindex[colpali]'"
42+
) from e
43+
44+
device = "cuda" if torch.cuda.is_available() else "cpu"
45+
46+
# Determine model type from name
47+
if "colpali" in model_name.lower():
48+
model = ColPali.from_pretrained(
49+
model_name, torch_dtype=torch.bfloat16, device_map=device
50+
)
51+
processor = ColPaliProcessor.from_pretrained(model_name)
52+
elif "colqwen" in model_name.lower():
53+
model = ColQwen2.from_pretrained(
54+
model_name, torch_dtype=torch.bfloat16, device_map=device
55+
)
56+
processor = ColQwen2Processor.from_pretrained(model_name)
57+
elif "colsmol" in model_name.lower():
58+
model = ColSmol.from_pretrained(
59+
model_name, torch_dtype=torch.bfloat16, device_map=device
60+
)
61+
processor = ColSmolProcessor.from_pretrained(model_name)
62+
else:
63+
# Fallback to ColPali for backwards compatibility
64+
model = ColPali.from_pretrained(
65+
model_name, torch_dtype=torch.bfloat16, device_map=device
66+
)
67+
processor = ColPaliProcessor.from_pretrained(model_name)
68+
69+
# Detect dimension
70+
dimension = _detect_colpali_dimension(model, processor, device)
71+
72+
return ColPaliModelInfo(
73+
model=model,
74+
processor=processor,
75+
dimension=dimension,
76+
device=device,
77+
)
78+
79+
80+
def _detect_colpali_dimension(model: Any, processor: Any, device: Any) -> int:
81+
"""Detect ColPali embedding dimension from the actual model config."""
82+
# Try to access embedding dimension
83+
if hasattr(model.config, "embedding_dim"):
84+
dim = model.config.embedding_dim
85+
else:
86+
# Fallback: infer from output shape with dummy data
87+
from PIL import Image
88+
import numpy as np
89+
import torch
90+
91+
dummy_img = Image.fromarray(np.zeros((224, 224, 3), np.uint8))
92+
# Use the processor to process the dummy image
93+
processed = processor.process_images([dummy_img]).to(device)
94+
with torch.no_grad():
95+
output = model(**processed)
96+
dim = int(output.shape[-1])
97+
if isinstance(dim, int):
98+
return dim
99+
else:
100+
raise ValueError(f"Expected integer dimension, got {type(dim)}: {dim}")
101+
return dim
102+
103+
104+
class ColPaliEmbedImage(op.FunctionSpec):
105+
"""
106+
`ColPaliEmbedImage` embeds images using ColVision multimodal models.
107+
108+
Supports ALL models available in the colpali-engine library, including:
109+
- ColPali models (colpali-*): PaliGemma-based, best for general document retrieval
110+
- ColQwen2 models (colqwen-*): Qwen2-VL-based, excellent for multilingual text (29+ languages) and general vision
111+
- ColSmol models (colsmol-*): Lightweight, good for resource-constrained environments
112+
- Any future ColVision models supported by colpali-engine
113+
114+
These models use late interaction between image patch embeddings and text token
115+
embeddings for retrieval.
116+
117+
Args:
118+
model: Any ColVision model name supported by colpali-engine
119+
(e.g., "vidore/colpali-v1.2", "vidore/colqwen2.5-v0.2", "vidore/colsmol-v1.0")
120+
See https://github.com/illuin-tech/colpali for the complete list of supported models.
121+
122+
Note:
123+
This function requires the optional colpali-engine dependency.
124+
Install it with: pip install 'cocoindex[colpali]'
125+
"""
126+
127+
model: str
128+
129+
130+
@op.executor_class(
131+
gpu=True,
132+
cache=True,
133+
behavior_version=1,
134+
)
135+
class ColPaliEmbedImageExecutor:
136+
"""Executor for ColVision image embedding (ColPali, ColQwen2, ColSmol, etc.)."""
137+
138+
spec: ColPaliEmbedImage
139+
_model_info: ColPaliModelInfo
140+
141+
def analyze(self) -> type:
142+
# Get shared model and dimension
143+
self._model_info = _get_colpali_model_and_processor(self.spec.model)
144+
145+
# Return multi-vector type: Variable patches x Fixed hidden dimension
146+
dimension = self._model_info.dimension
147+
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore
148+
149+
def __call__(self, img_bytes: bytes) -> Any:
150+
try:
151+
from PIL import Image
152+
import torch
153+
import io
154+
except ImportError as e:
155+
raise ImportError(
156+
"Required dependencies (PIL, torch) are missing for ColVision image embedding."
157+
) from e
158+
159+
model = self._model_info.model
160+
processor = self._model_info.processor
161+
device = self._model_info.device
162+
163+
pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
164+
inputs = processor.process_images([pil_image]).to(device)
165+
with torch.no_grad():
166+
embeddings = model(**inputs)
167+
168+
# Return multi-vector format: [patches, hidden_dim]
169+
if len(embeddings.shape) != 3:
170+
raise ValueError(
171+
f"Expected 3D tensor [batch, patches, hidden_dim], got shape {embeddings.shape}"
172+
)
173+
174+
# Keep patch-level embeddings: [batch, patches, hidden_dim] -> [patches, hidden_dim]
175+
patch_embeddings = embeddings[0] # Remove batch dimension
176+
177+
return patch_embeddings.cpu().to(torch.float32).numpy()
178+
179+
180+
class ColPaliEmbedQuery(op.FunctionSpec):
181+
"""
182+
`ColPaliEmbedQuery` embeds text queries using ColVision multimodal models.
183+
184+
Supports ALL models available in the colpali-engine library, including:
185+
- ColPali models (colpali-*): PaliGemma-based, best for general document retrieval
186+
- ColQwen2 models (colqwen-*): Qwen2-VL-based, excellent for multilingual text (29+ languages) and general vision
187+
- ColSmol models (colsmol-*): Lightweight, good for resource-constrained environments
188+
- Any future ColVision models supported by colpali-engine
189+
190+
This produces query embeddings compatible with ColVision image embeddings
191+
for late interaction scoring (MaxSim).
192+
193+
Args:
194+
model: Any ColVision model name supported by colpali-engine
195+
(e.g., "vidore/colpali-v1.2", "vidore/colqwen2.5-v0.2", "vidore/colsmol-v1.0")
196+
See https://github.com/illuin-tech/colpali for the complete list of supported models.
197+
198+
Note:
199+
This function requires the optional colpali-engine dependency.
200+
Install it with: pip install 'cocoindex[colpali]'
201+
"""
202+
203+
model: str
204+
205+
206+
@op.executor_class(
207+
gpu=True,
208+
cache=True,
209+
behavior_version=1,
210+
)
211+
class ColPaliEmbedQueryExecutor:
212+
"""Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
213+
214+
spec: ColPaliEmbedQuery
215+
_model_info: ColPaliModelInfo
216+
217+
def analyze(self) -> type:
218+
# Get shared model and dimension
219+
self._model_info = _get_colpali_model_and_processor(self.spec.model)
220+
221+
# Return multi-vector type: Variable tokens x Fixed hidden dimension
222+
dimension = self._model_info.dimension
223+
return Vector[Vector[np.float32, Literal[dimension]]] # type: ignore
224+
225+
def __call__(self, query: str) -> Any:
226+
try:
227+
import torch
228+
except ImportError as e:
229+
raise ImportError(
230+
"Required dependencies (torch) are missing for ColVision query embedding."
231+
) from e
232+
233+
model = self._model_info.model
234+
processor = self._model_info.processor
235+
device = self._model_info.device
236+
237+
inputs = processor.process_queries([query]).to(device)
238+
with torch.no_grad():
239+
embeddings = model(**inputs)
240+
241+
# Return multi-vector format: [tokens, hidden_dim]
242+
if len(embeddings.shape) != 3:
243+
raise ValueError(
244+
f"Expected 3D tensor [batch, tokens, hidden_dim], got shape {embeddings.shape}"
245+
)
246+
247+
# Keep token-level embeddings: [batch, tokens, hidden_dim] -> [tokens, hidden_dim]
248+
token_embeddings = embeddings[0] # Remove batch dimension
249+
250+
return token_embeddings.cpu().to(torch.float32).numpy()

0 commit comments

Comments
 (0)