Skip to content

Commit dcb9b49

Browse files
authored
[ENH] Transpose search api response (#5414)
## Description of changes

_Summarize the changes made by this PR._

- Improvements & Bug fixes
  - Transpose the search response from row-major to column-major for consistency with our existing endpoints
- New functionality
  - N/A

## Test plan

_How are these changes tested?_

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Migration plan

_Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_

## Observability plan

_What is the plan to instrument and monitor this change?_

## Documentation Changes

_Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent fc31e6a commit dcb9b49

File tree

9 files changed

+150
-65
lines changed

9 files changed

+150
-65
lines changed

chromadb/api/async_fastapi.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
GetResult,
4242
QueryResult,
4343
SearchResult,
44-
SearchRecord,
4544
CollectionMetadata,
4645
validate_batch,
4746
convert_np_embeddings_to_list,
@@ -420,20 +419,15 @@ async def _search(
420419
json=payload,
421420
)
422421

423-
# Parse response into SearchResult
424-
results = []
425-
for batch_results in resp_json.get("results", []):
426-
batch = []
427-
for record in batch_results:
428-
batch.append(SearchRecord(
429-
id=record["id"],
430-
document=record.get("document"),
431-
embedding=record.get("embedding"),
432-
metadata=record.get("metadata"),
433-
score=record.get("score"),
434-
))
435-
results.append(batch)
436-
return results
422+
# Return the column-major format directly
423+
return SearchResult(
424+
ids=resp_json.get("ids", []),
425+
documents=resp_json.get("documents", []),
426+
embeddings=resp_json.get("embeddings", []),
427+
metadatas=resp_json.get("metadatas", []),
428+
scores=resp_json.get("scores", []),
429+
select=resp_json.get("select", [])
430+
)
437431

438432
@trace_method("AsyncFastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
439433
@override

chromadb/api/fastapi.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
GetResult,
3333
QueryResult,
3434
SearchResult,
35-
SearchRecord,
3635
CollectionMetadata,
3736
validate_batch,
3837
convert_np_embeddings_to_list,
@@ -377,20 +376,15 @@ def _search(
377376
json=payload,
378377
)
379378

380-
# Parse response into SearchResult
381-
results = []
382-
for batch_results in resp_json.get("results", []):
383-
batch = []
384-
for record in batch_results:
385-
batch.append(SearchRecord(
386-
id=record["id"],
387-
document=record.get("document"),
388-
embedding=record.get("embedding"),
389-
metadata=record.get("metadata"),
390-
score=record.get("score"),
391-
))
392-
results.append(batch)
393-
return results
379+
# Return the column-major format directly
380+
return SearchResult(
381+
ids=resp_json.get("ids", []),
382+
documents=resp_json.get("documents", []),
383+
embeddings=resp_json.get("embeddings", []),
384+
metadatas=resp_json.get("metadatas", []),
385+
scores=resp_json.get("scores", []),
386+
select=resp_json.get("select", [])
387+
)
394388

395389
@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
396390
@override

chromadb/api/models/AsyncCollection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,13 @@ async def search(
305305
- select: Select configuration for fields to return (defaults to empty)
306306
307307
Returns:
308-
SearchResult: List of search results for each search payload.
309-
Each result is a list of SearchRecord objects.
308+
SearchResult: Column-major format response with:
309+
- ids: List of result IDs for each search payload
310+
- documents: Optional documents for each payload
311+
- embeddings: Optional embeddings for each payload
312+
- metadatas: Optional metadata for each payload
313+
- scores: Optional scores for each payload
314+
- select: List of selected fields for each payload
310315
311316
Raises:
312317
NotImplementedError: For local/segment API implementations

chromadb/api/models/Collection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,13 @@ def search(
309309
- select: Select configuration for fields to return (defaults to empty)
310310
311311
Returns:
312-
SearchResult: List of search results for each search payload.
313-
Each result is a list of SearchRecord objects.
312+
SearchResult: Column-major format response with:
313+
- ids: List of result IDs for each search payload
314+
- documents: Optional documents for each payload
315+
- embeddings: Optional embeddings for each payload
316+
- metadatas: Optional metadata for each payload
317+
- scores: Optional scores for each payload
318+
- select: List of selected fields for each payload
314319
315320
Raises:
316321
NotImplementedError: For local/segment API implementations

chromadb/api/types.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
cast,
1111
Literal,
1212
get_args,
13+
TYPE_CHECKING,
1314
)
1415
from numpy.typing import NDArray
1516
import numpy as np
@@ -29,6 +30,9 @@
2930
WhereDocument,
3031
SparseVector,
3132
)
33+
34+
if TYPE_CHECKING:
35+
from chromadb.execution.expression.operator import SelectField
3236
from inspect import signature
3337
from tenacity import retry
3438
from abc import abstractmethod
@@ -44,7 +48,6 @@
4448
"WhereDocument",
4549
"UpdateCollectionMetadata",
4650
"UpdateMetadata",
47-
"SearchRecord",
4851
"SearchResult",
4952
"SparseVector",
5053
"is_valid_sparse_vector",
@@ -490,17 +493,26 @@ class QueryResult(TypedDict):
490493
included: Include
491494

492495

493-
class SearchRecord(TypedDict):
494-
"""Individual record returned from a search operation"""
495-
id: str
496-
document: Optional[str]
497-
embedding: Optional[List[float]]
498-
metadata: Optional[Metadata]
499-
score: Optional[float]
500-
501-
502-
# SearchResult is a list of results for each search payload
503-
SearchResult = List[List[SearchRecord]]
496+
class SearchResult(TypedDict):
497+
"""Column-major response from the search API matching Rust SearchResponse structure.
498+
499+
This is the format returned to users:
500+
- ids: Always present, list of result IDs for each search payload
501+
- documents: Optional per payload, None if not requested
502+
- embeddings: Optional per payload, None if not requested
503+
- metadatas: Optional per payload, None if not requested
504+
- scores: Optional per payload, None if not requested
505+
- select: List of selected fields for each payload (sorted)
506+
507+
Each top-level list index corresponds to a search payload.
508+
Within each payload, the inner lists are aligned by record index.
509+
"""
510+
ids: List[List[str]]
511+
documents: List[Optional[List[Optional[str]]]]
512+
embeddings: List[Optional[List[Optional[List[float]]]]]
513+
metadatas: List[Optional[List[Optional[Dict[str, Any]]]]]
514+
scores: List[Optional[List[Optional[float]]]]
515+
select: List[List[Union["SelectField", str]]] # List of SelectField enums or string field names for each payload
504516

505517

506518
class UpdateRequest(TypedDict):

clients/new-js/packages/chromadb/src/api/types.gen.ts

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,20 +205,21 @@ export type SearchPayload = {
205205
};
206206
};
207207

208-
export type SearchRecord = {
209-
document?: string | null;
210-
embedding?: Array<number> | null;
211-
id: string;
212-
metadata?: null | HashMap;
213-
score?: number | null;
214-
};
215-
216208
export type SearchRequestPayload = {
217209
searches: Array<SearchPayload>;
218210
};
219211

220212
export type SearchResponse = {
221-
results: Array<Array<SearchRecord>>;
213+
documents: Array<Array<string | null> | null>;
214+
embeddings: Array<Array<Array<number> | null> | null>;
215+
ids: Array<Array<string>>;
216+
metadatas: Array<Array<null | HashMap> | null>;
217+
scores: Array<Array<number | null> | null>;
218+
select: Array<Array<SelectField>>;
219+
};
220+
221+
export type SelectField = 'Document' | 'Embedding' | 'Metadata' | 'Score' | {
222+
MetadataField: string;
222223
};
223224

224225
export type SpannConfiguration = {

rust/frontend/src/impls/service_based_frontend.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,6 +1641,8 @@ impl ServiceBasedFrontend {
16411641
}
16421642

16431643
// Create a single Search plan with one scan and the payloads from the request
1644+
// Clone the searches to use them later for aggregating select fields
1645+
let searches_for_select = request.searches.clone();
16441646
let search_plan = Search {
16451647
scan: Scan {
16461648
collection_and_segments,
@@ -1691,13 +1693,7 @@ impl ServiceBasedFrontend {
16911693
},
16921694
}
16931695

1694-
Ok(SearchResponse {
1695-
results: result
1696-
.results
1697-
.into_iter()
1698-
.map(|result| result.records)
1699-
.collect(),
1700-
})
1696+
Ok((result, searches_for_select).into())
17011697
}
17021698

17031699
pub async fn search(&mut self, request: SearchRequest) -> Result<SearchResponse, QueryError> {

rust/types/src/api_types.rs

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use crate::operator::GetResult;
55
use crate::operator::KnnBatchResult;
66
use crate::operator::KnnProjectionRecord;
77
use crate::operator::ProjectionRecord;
8-
use crate::operator::SearchRecord;
8+
use crate::operator::SearchResult;
9+
use crate::operator::SelectField;
910
use crate::plan::PlanToProtoError;
1011
use crate::plan::SearchPayload;
1112
use crate::validators::{
@@ -1869,7 +1870,84 @@ impl SearchRequest {
18691870

18701871
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
18711872
pub struct SearchResponse {
1872-
pub results: Vec<Vec<SearchRecord>>,
1873+
pub ids: Vec<Vec<String>>,
1874+
pub documents: Vec<Option<Vec<Option<String>>>>,
1875+
pub embeddings: Vec<Option<Vec<Option<Vec<f32>>>>>,
1876+
pub metadatas: Vec<Option<Vec<Option<Metadata>>>>,
1877+
pub scores: Vec<Option<Vec<Option<f32>>>>,
1878+
pub select: Vec<Vec<SelectField>>,
1879+
}
1880+
1881+
impl From<(SearchResult, Vec<SearchPayload>)> for SearchResponse {
1882+
fn from((result, payloads): (SearchResult, Vec<SearchPayload>)) -> Self {
1883+
let num_payloads = payloads.len();
1884+
let mut res = Self {
1885+
ids: Vec::with_capacity(num_payloads),
1886+
documents: Vec::with_capacity(num_payloads),
1887+
embeddings: Vec::with_capacity(num_payloads),
1888+
metadatas: Vec::with_capacity(num_payloads),
1889+
scores: Vec::with_capacity(num_payloads),
1890+
select: Vec::with_capacity(num_payloads),
1891+
};
1892+
1893+
for (payload_result, payload) in result.results.into_iter().zip(payloads) {
1894+
// Get the sorted select fields for this payload
1895+
let mut payload_select = Vec::from_iter(payload.select.fields.iter().cloned());
1896+
payload_select.sort();
1897+
1898+
let num_records = payload_result.records.len();
1899+
let mut ids = Vec::with_capacity(num_records);
1900+
let mut documents = Vec::with_capacity(num_records);
1901+
let mut embeddings = Vec::with_capacity(num_records);
1902+
let mut metadatas = Vec::with_capacity(num_records);
1903+
let mut scores = Vec::with_capacity(num_records);
1904+
1905+
for record in payload_result.records {
1906+
ids.push(record.id);
1907+
documents.push(record.document);
1908+
embeddings.push(record.embedding);
1909+
metadatas.push(record.metadata);
1910+
scores.push(record.score);
1911+
}
1912+
1913+
res.ids.push(ids);
1914+
res.select.push(payload_select.clone());
1915+
1916+
// Push documents if requested by this payload, otherwise None
1917+
res.documents.push(
1918+
payload_select
1919+
.binary_search(&SelectField::Document)
1920+
.is_ok()
1921+
.then_some(documents),
1922+
);
1923+
1924+
// Push embeddings if requested by this payload, otherwise None
1925+
res.embeddings.push(
1926+
payload_select
1927+
.binary_search(&SelectField::Embedding)
1928+
.is_ok()
1929+
.then_some(embeddings),
1930+
);
1931+
1932+
// Push metadatas if requested by this payload, otherwise None
1933+
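// Include metadatas if SelectField::Metadata is present or if any SelectField::MetadataField(_) is present; only the last element of the sorted list is checked for the latter, which assumes MetadataField orders after all other variants (TODO confirm against the enum declaration order)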
1934+
let has_metadata = payload_select.binary_search(&SelectField::Metadata).is_ok()
1935+
|| payload_select
1936+
.last()
1937+
.is_some_and(|field| matches!(field, SelectField::MetadataField(_)));
1938+
res.metadatas.push(has_metadata.then_some(metadatas));
1939+
1940+
// Push scores if requested by this payload, otherwise None
1941+
res.scores.push(
1942+
payload_select
1943+
.binary_search(&SelectField::Score)
1944+
.is_ok()
1945+
.then_some(scores),
1946+
);
1947+
}
1948+
1949+
res
1950+
}
18731951
}
18741952

18751953
#[derive(Error, Debug)]

rust/types/src/execution/operator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ impl TryFrom<Rank> for chroma_proto::Rank {
942942
}
943943
}
944944

945-
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
945+
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ToSchema)]
946946
pub enum SelectField {
947947
// Predefined fields
948948
Document,

0 commit comments

Comments
 (0)