Skip to content

Commit 28662d2

Browse files
authored
feat: add ChromaDB target connector (Phase 1)
- Implements basic ChromaDB connector with core CRUD operations - Schema mapping for key/value fields to ChromaDB documents - Supports vector embeddings storage - Optional dependency setup - Follows LanceDB target pattern Related to #1214
1 parent 3870daa commit 28662d2

File tree

1 file changed

+348
-0
lines changed

1 file changed

+348
-0
lines changed
Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
import dataclasses
2+
import logging
3+
import uuid
4+
from typing import Any
5+
6+
try:
7+
import chromadb # type: ignore
8+
except ImportError as e:
9+
raise ImportError(
10+
"ChromaDB optional dependency not installed. "
11+
"Install with: pip install 'cocoindex[chromadb]'"
12+
) from e
13+
14+
from .. import op
15+
from ..typing import (
16+
FieldSchema,
17+
EnrichedValueType,
18+
BasicValueType,
19+
StructType,
20+
ValueType,
21+
TableType,
22+
)
23+
from ..index import IndexOptions
24+
25+
_logger = logging.getLogger(__name__)
26+
27+
28+
class ChromaDB(op.TargetSpec):
29+
"""ChromaDB target specification.
30+
31+
Args:
32+
collection_name: Name of the ChromaDB collection
33+
client_path: Path for persistent client (if None, uses ephemeral client)
34+
client_settings: Optional settings dict for ChromaDB client
35+
"""
36+
collection_name: str
37+
client_path: str | None = None
38+
client_settings: dict[str, Any] | None = None
39+
40+
41+
@dataclasses.dataclass
42+
class _State:
43+
key_field_schema: FieldSchema
44+
value_fields_schema: list[FieldSchema]
45+
collection_name: str
46+
client_path: str | None = None
47+
client_settings: dict[str, Any] | None = None
48+
49+
50+
@dataclasses.dataclass
51+
class _TableKey:
52+
client_path: str
53+
collection_name: str
54+
55+
56+
@dataclasses.dataclass
57+
class _MutateContext:
58+
collection: Any # chromadb.Collection
59+
key_field_schema: FieldSchema
60+
value_fields_schema: list[FieldSchema]
61+
62+
63+
def _convert_value_for_chromadb(value_type: ValueType, v: Any) -> Any:
64+
"""Convert value to ChromaDB-compatible format."""
65+
if v is None:
66+
return None
67+
68+
if isinstance(value_type, BasicValueType):
69+
# Handle UUID conversion
70+
if isinstance(v, uuid.UUID):
71+
return str(v)
72+
73+
# Handle Range type
74+
if value_type.kind == "Range":
75+
return {"start": v[0], "end": v[1]}
76+
77+
# Handle Vector type - ChromaDB stores as list of floats
78+
if value_type.vector is not None:
79+
return [float(_convert_value_for_chromadb(value_type.vector.element_type, e)) for e in v]
80+
81+
return v
82+
83+
elif isinstance(value_type, StructType):
84+
return _convert_fields_for_chromadb(value_type.fields, v)
85+
86+
elif isinstance(value_type, TableType):
87+
if isinstance(v, list):
88+
return [_convert_fields_for_chromadb(value_type.row.fields, item) for item in v]
89+
else:
90+
key_fields = value_type.row.fields[:value_type.num_key_parts]
91+
value_fields = value_type.row.fields[value_type.num_key_parts:]
92+
return [
93+
_convert_fields_for_chromadb(key_fields, item[:value_type.num_key_parts])
94+
| _convert_fields_for_chromadb(value_fields, item[value_type.num_key_parts:])
95+
for item in v
96+
]
97+
98+
return v
99+
100+
101+
def _convert_fields_for_chromadb(fields: list[FieldSchema], v: Any) -> dict:
102+
"""Convert fields to ChromaDB document format."""
103+
if isinstance(v, dict):
104+
return {
105+
field.name: _convert_value_for_chromadb(field.value_type.type, v.get(field.name))
106+
for field in fields
107+
}
108+
elif isinstance(v, tuple):
109+
return {
110+
field.name: _convert_value_for_chromadb(field.value_type.type, value)
111+
for field, value in zip(fields, v)
112+
}
113+
else:
114+
# Single value case
115+
field = fields[0]
116+
return {field.name: _convert_value_for_chromadb(field.value_type.type, v)}
117+
118+
119+
def _extract_embedding(value_dict: dict, value_fields: list[FieldSchema]) -> list[float] | None:
120+
"""Extract embedding vector from value fields if present."""
121+
for field in value_fields:
122+
if isinstance(field.value_type.type, BasicValueType):
123+
if field.value_type.type.vector is not None:
124+
vec = value_dict.get(field.name)
125+
if vec is not None:
126+
return [float(x) for x in vec]
127+
return None
128+
129+
130+
@op.target_connector(
131+
spec_cls=ChromaDB, persistent_key_type=_TableKey, setup_state_cls=_State
132+
)
133+
class _Connector:
134+
@staticmethod
135+
def get_persistent_key(spec: ChromaDB) -> _TableKey:
136+
return _TableKey(
137+
client_path=spec.client_path or ":memory:",
138+
collection_name=spec.collection_name
139+
)
140+
141+
@staticmethod
142+
def get_setup_state(
143+
spec: ChromaDB,
144+
key_fields_schema: list[FieldSchema],
145+
value_fields_schema: list[FieldSchema],
146+
index_options: IndexOptions,
147+
) -> _State:
148+
if len(key_fields_schema) != 1:
149+
raise ValueError("ChromaDB only supports a single key field")
150+
151+
if index_options.vector_indexes is not None:
152+
_logger.warning(
153+
"Vector index configuration not yet supported in ChromaDB target (Phase 1). "
154+
"Embeddings will be stored but indexing options are ignored."
155+
)
156+
157+
return _State(
158+
key_field_schema=key_fields_schema[0],
159+
value_fields_schema=value_fields_schema,
160+
collection_name=spec.collection_name,
161+
client_path=spec.client_path,
162+
client_settings=spec.client_settings,
163+
)
164+
165+
@staticmethod
166+
def describe(key: _TableKey) -> str:
167+
return f"ChromaDB collection {key.collection_name}@{key.client_path}"
168+
169+
@staticmethod
170+
def check_state_compatibility(
171+
previous: _State, current: _State
172+
) -> op.TargetStateCompatibility:
173+
if (
174+
previous.key_field_schema != current.key_field_schema
175+
or previous.value_fields_schema != current.value_fields_schema
176+
):
177+
return op.TargetStateCompatibility.NOT_COMPATIBLE
178+
return op.TargetStateCompatibility.COMPATIBLE
179+
180+
@staticmethod
181+
async def apply_setup_change(
182+
key: _TableKey, previous: _State | None, current: _State | None
183+
) -> None:
184+
latest_state = current or previous
185+
if not latest_state:
186+
return
187+
188+
# Create or connect to ChromaDB client
189+
if latest_state.client_path and latest_state.client_path != ":memory:":
190+
client = chromadb.PersistentClient(
191+
path=latest_state.client_path,
192+
settings=chromadb.Settings(**(latest_state.client_settings or {}))
193+
)
194+
else:
195+
client = chromadb.Client(
196+
settings=chromadb.Settings(**(latest_state.client_settings or {}))
197+
)
198+
199+
# Handle collection lifecycle
200+
if previous is not None and current is None:
201+
# Delete collection
202+
try:
203+
client.delete_collection(name=key.collection_name)
204+
except Exception as e:
205+
_logger.warning(
206+
"Failed to delete collection %s: %s",
207+
key.collection_name,
208+
e
209+
)
210+
return
211+
212+
if current is not None:
213+
# Check if schema changed (not compatible)
214+
reuse = previous is not None and _Connector.check_state_compatibility(
215+
previous, current
216+
) == op.TargetStateCompatibility.COMPATIBLE
217+
218+
if not reuse and previous is not None:
219+
# Schema changed, need to recreate
220+
try:
221+
client.delete_collection(name=key.collection_name)
222+
except Exception:
223+
pass # Collection might not exist
224+
225+
# Create or get collection
226+
try:
227+
collection = client.get_or_create_collection(
228+
name=current.collection_name
229+
)
230+
_logger.info(
231+
"ChromaDB collection %s ready with %d items",
232+
current.collection_name,
233+
collection.count()
234+
)
235+
except Exception as e:
236+
raise RuntimeError(
237+
f"Failed to create/open ChromaDB collection {current.collection_name}: {e}"
238+
) from e
239+
240+
@staticmethod
241+
async def prepare(
242+
spec: ChromaDB,
243+
setup_state: _State,
244+
) -> _MutateContext:
245+
# Connect to client
246+
if setup_state.client_path and setup_state.client_path != ":memory:":
247+
client = chromadb.PersistentClient(
248+
path=setup_state.client_path,
249+
settings=chromadb.Settings(**(setup_state.client_settings or {}))
250+
)
251+
else:
252+
client = chromadb.Client(
253+
settings=chromadb.Settings(**(setup_state.client_settings or {}))
254+
)
255+
256+
# Get collection
257+
collection = client.get_collection(name=spec.collection_name)
258+
259+
return _MutateContext(
260+
collection=collection,
261+
key_field_schema=setup_state.key_field_schema,
262+
value_fields_schema=setup_state.value_fields_schema,
263+
)
264+
265+
@staticmethod
266+
async def mutate(
267+
*all_mutations: tuple[_MutateContext, dict[Any, dict[str, Any] | None]],
268+
) -> None:
269+
for context, mutations in all_mutations:
270+
ids_to_upsert = []
271+
metadatas_to_upsert = []
272+
documents_to_upsert = []
273+
embeddings_to_upsert = []
274+
ids_to_delete = []
275+
276+
key_name = context.key_field_schema.name
277+
278+
for key, value in mutations.items():
279+
# Convert key to string ID
280+
if isinstance(key, uuid.UUID):
281+
key_id = str(key)
282+
else:
283+
key_id = str(key)
284+
285+
if value is None:
286+
# Deletion
287+
ids_to_delete.append(key_id)
288+
else:
289+
# Upsert
290+
ids_to_upsert.append(key_id)
291+
292+
# Convert value fields to metadata
293+
metadata = {}
294+
embedding = None
295+
document_text = None
296+
297+
for field_schema, (field_name, field_value) in zip(
298+
context.value_fields_schema, value.items()
299+
):
300+
converted = _convert_value_for_chromadb(
301+
field_schema.value_type.type, field_value
302+
)
303+
304+
# Check if this is an embedding field
305+
if isinstance(field_schema.value_type.type, BasicValueType):
306+
if field_schema.value_type.type.vector is not None:
307+
embedding = converted
308+
continue
309+
310+
# Store as metadata (ChromaDB supports str, int, float, bool)
311+
if isinstance(converted, (str, int, float, bool)):
312+
metadata[field_name] = converted
313+
elif converted is None:
314+
metadata[field_name] = None
315+
else:
316+
# Convert complex types to string
317+
import json
318+
metadata[field_name] = json.dumps(converted)
319+
320+
# Use key as document if no specific text field
321+
document_text = key_id
322+
documents_to_upsert.append(document_text)
323+
metadatas_to_upsert.append(metadata)
324+
if embedding:
325+
embeddings_to_upsert.append(embedding)
326+
327+
# Execute deletions
328+
if ids_to_delete:
329+
try:
330+
context.collection.delete(ids=ids_to_delete)
331+
except Exception as e:
332+
_logger.warning("Failed to delete some IDs: %s", e)
333+
334+
# Execute upserts
335+
if ids_to_upsert:
336+
if embeddings_to_upsert:
337+
context.collection.upsert(
338+
ids=ids_to_upsert,
339+
embeddings=embeddings_to_upsert,
340+
metadatas=metadatas_to_upsert,
341+
documents=documents_to_upsert,
342+
)
343+
else:
344+
context.collection.upsert(
345+
ids=ids_to_upsert,
346+
metadatas=metadatas_to_upsert,
347+
documents=documents_to_upsert,
348+
)

0 commit comments

Comments
 (0)