Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,17 @@ def normalize_embeddings(
if target.ndim == 1:
return [target]
elif target.ndim == 2:
return [row for row in target]
return list(target)
elif isinstance(target, list):
# One PyEmbedding
if isinstance(target[0], (int, float)) and not isinstance(target[0], bool):
first = target[0]
if isinstance(first, (int, float)) and not isinstance(first, bool):
return [np.array(target, dtype=np.float32)]
elif isinstance(target[0], np.ndarray):
elif isinstance(first, np.ndarray):
return cast(Embeddings, target)
elif isinstance(target[0], list):
if isinstance(target[0][0], (int, float)) and not isinstance(
target[0][0], bool
elif isinstance(first, list):
inner_first = first[0]
if isinstance(inner_first, (int, float)) and not isinstance(
inner_first, bool
):
return [np.array(row, dtype=np.float32) for row in target]

Expand Down Expand Up @@ -1270,10 +1271,12 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
raise ValueError(
f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
)
if not all([isinstance(e, np.ndarray) for e in embeddings]):
# Use generator expression for memory efficiency
if not all(isinstance(e, np.ndarray) for e in embeddings):
types_seen = set(type(e).__name__ for e in embeddings)
raise ValueError(
"Expected each embedding in the embeddings to be a numpy array, got "
f"{list(set([type(e).__name__ for e in embeddings]))}"
f"{list(types_seen)}"
)
for i, embedding in enumerate(embeddings):
if embedding.ndim == 0:
Expand All @@ -1285,13 +1288,13 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
f"Expected each embedding in the embeddings to be a 1-dimensional numpy array with at least 1 int/float value. Got a 1-dimensional numpy array with no values at pos {i}"
)

if embedding.dtype not in [
if embedding.dtype not in (
np.float16,
np.float32,
np.float64,
np.int32,
np.int64,
]:
):
raise ValueError(
"Expected each value in the embedding to be a int or float, got an embedding with "
f"{embedding.dtype} - {embedding}"
Expand Down
5 changes: 3 additions & 2 deletions chromadb/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ def to_dict(self) -> Dict[str, Any]:
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> "SparseVector":
"""Deserialize from transport format (strict - requires #type field)."""
if d.get(TYPE_KEY) != SPARSE_VECTOR_TYPE_VALUE:
type_val = d[TYPE_KEY]
if type_val != SPARSE_VECTOR_TYPE_VALUE:
raise ValueError(
f"Expected {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {d.get(TYPE_KEY)}"
f"Expected {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {type_val}"
)
return cls(indices=d["indices"], values=d["values"])

Expand Down
64 changes: 29 additions & 35 deletions chromadb/execution/expression/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,15 +655,15 @@ def from_dict(data: Dict[str, Any]) -> "Rank":
)

op = next(iter(data.keys()))
val = data[op]

if op == "$val":
value = data["$val"]
if not isinstance(value, (int, float)):
raise TypeError(f"$val requires a number, got {type(value).__name__}")
return Val(value)
if not isinstance(val, (int, float)):
raise TypeError(f"$val requires a number, got {type(val).__name__}")
return Val(val)

elif op == "$knn":
knn_data = data["$knn"]
knn_data = val
if not isinstance(knn_data, dict):
raise TypeError(f"$knn requires a dict, got {type(knn_data).__name__}")

Expand All @@ -674,7 +674,8 @@ def from_dict(data: Dict[str, Any]) -> "Rank":

if isinstance(query, dict):
# SparseVector case - deserialize from transport format
if query.get(TYPE_KEY) == SPARSE_VECTOR_TYPE_VALUE:
query_type = query.get(TYPE_KEY)
if query_type == SPARSE_VECTOR_TYPE_VALUE:
query = SparseVector.from_dict(query)
else:
# Old format or invalid - try to construct directly
Expand Down Expand Up @@ -725,7 +726,7 @@ def from_dict(data: Dict[str, Any]) -> "Rank":
)

elif op == "$sum":
ranks_data = data["$sum"]
ranks_data = val
if not isinstance(ranks_data, (list, tuple)):
raise TypeError(
f"$sum requires a list, got {type(ranks_data).__name__}"
Expand All @@ -734,15 +735,14 @@ def from_dict(data: Dict[str, Any]) -> "Rank":
raise ValueError(
f"$sum requires at least 2 ranks, got {len(ranks_data)}"
)

ranks = [Rank.from_dict(r) for r in ranks_data]
result = ranks[0]
for r in ranks[1:]:
result = result + r
# Evaluate and accumulate result, avoiding intermediate lists
result = Rank.from_dict(ranks_data[0])
for r in ranks_data[1:]:
result = result + Rank.from_dict(r)
return result

elif op == "$sub":
sub_data = data["$sub"]
sub_data = val
if not isinstance(sub_data, dict):
raise TypeError(
f"$sub requires a dict with 'left' and 'right', got {type(sub_data).__name__}"
Expand All @@ -755,7 +755,7 @@ def from_dict(data: Dict[str, Any]) -> "Rank":
return left - right

elif op == "$mul":
ranks_data = data["$mul"]
ranks_data = val
if not isinstance(ranks_data, (list, tuple)):
raise TypeError(
f"$mul requires a list, got {type(ranks_data).__name__}"
Expand All @@ -764,15 +764,13 @@ def from_dict(data: Dict[str, Any]) -> "Rank":
raise ValueError(
f"$mul requires at least 2 ranks, got {len(ranks_data)}"
)

ranks = [Rank.from_dict(r) for r in ranks_data]
result = ranks[0]
for r in ranks[1:]:
result = result * r
result = Rank.from_dict(ranks_data[0])
for r in ranks_data[1:]:
result = result * Rank.from_dict(r)
return result

elif op == "$div":
div_data = data["$div"]
div_data = val
if not isinstance(div_data, dict):
raise TypeError(
f"$div requires a dict with 'left' and 'right', got {type(div_data).__name__}"
Expand All @@ -785,31 +783,31 @@ def from_dict(data: Dict[str, Any]) -> "Rank":
return left / right

elif op == "$abs":
child_data = data["$abs"]
child_data = val
if not isinstance(child_data, dict):
raise TypeError(
f"$abs requires a rank dict, got {type(child_data).__name__}"
)
return abs(Rank.from_dict(child_data))

elif op == "$exp":
child_data = data["$exp"]
child_data = val
if not isinstance(child_data, dict):
raise TypeError(
f"$exp requires a rank dict, got {type(child_data).__name__}"
)
return Rank.from_dict(child_data).exp()

elif op == "$log":
child_data = data["$log"]
child_data = val
if not isinstance(child_data, dict):
raise TypeError(
f"$log requires a rank dict, got {type(child_data).__name__}"
)
return Rank.from_dict(child_data).log()

elif op == "$max":
ranks_data = data["$max"]
ranks_data = val
if not isinstance(ranks_data, (list, tuple)):
raise TypeError(
f"$max requires a list, got {type(ranks_data).__name__}"
Expand All @@ -818,15 +816,13 @@ def from_dict(data: Dict[str, Any]) -> "Rank":
raise ValueError(
f"$max requires at least 2 ranks, got {len(ranks_data)}"
)

ranks = [Rank.from_dict(r) for r in ranks_data]
result = ranks[0]
for r in ranks[1:]:
result = result.max(r)
result = Rank.from_dict(ranks_data[0])
for r in ranks_data[1:]:
result = result.max(Rank.from_dict(r))
return result

elif op == "$min":
ranks_data = data["$min"]
ranks_data = val
if not isinstance(ranks_data, (list, tuple)):
raise TypeError(
f"$min requires a list, got {type(ranks_data).__name__}"
Expand All @@ -835,11 +831,9 @@ def from_dict(data: Dict[str, Any]) -> "Rank":
raise ValueError(
f"$min requires at least 2 ranks, got {len(ranks_data)}"
)

ranks = [Rank.from_dict(r) for r in ranks_data]
result = ranks[0]
for r in ranks[1:]:
result = result.min(r)
result = Rank.from_dict(ranks_data[0])
for r in ranks_data[1:]:
result = result.min(Rank.from_dict(r))
return result

else:
Expand Down