Skip to content

Commit 6eec387

Browse files
authored
[BUG]: Have stats function paginate through output collection (#5906)
## Description of changes - Improvements & Bug fixes - `collection.statistics()` did not paginate through the output collection; this diff fixes that. - This diff moves all the statistics function wrappers to a file in the `chromadb/utils` directory. - New functionality - ... ## Test plan Added a new test to test_statistics_wrapper.py that was confirmed to fail before this change. - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent c2ef003 commit 6eec387

File tree

3 files changed

+344
-202
lines changed

3 files changed

+344
-202
lines changed

chromadb/api/models/Collection.py

Lines changed: 0 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -552,169 +552,3 @@ def get_attached_function(self, name: str) -> "AttachedFunction":
552552
tenant=self.tenant,
553553
database=self.database,
554554
)
555-
556-
def _get_statistics_fn_name(self) -> str:
557-
"""Generate the default name for the statistics attached function.
558-
559-
Returns:
560-
str: The statistics function name
561-
"""
562-
return f"{self.name}_stats"
563-
564-
def enable_statistics(self, stats_collection_name: Optional[str] = None) -> "AttachedFunction":
565-
"""Enable statistics collection for this collection.
566-
567-
This attaches the statistics function which will automatically compute
568-
and update metadata value frequencies whenever records are added, updated,
569-
or deleted.
570-
571-
Args:
572-
stats_collection_name: Name of the collection where statistics will be stored.
573-
If None, defaults to "{collection_name}_statistics".
574-
575-
Returns:
576-
AttachedFunction: The attached statistics function
577-
578-
Example:
579-
>>> collection.enable_statistics()
580-
>>> collection.add(ids=["id1"], documents=["doc1"], metadatas=[{"key": "value"}])
581-
>>> # Statistics are automatically computed
582-
>>> stats = collection.statistics()
583-
"""
584-
if stats_collection_name is None:
585-
stats_collection_name = f"{self.name}_statistics"
586-
587-
return self.attach_function(
588-
name=self._get_statistics_fn_name(),
589-
function_id="statistics",
590-
output_collection=stats_collection_name,
591-
params=None,
592-
)
593-
594-
def _get_statistics_fn(self) -> "AttachedFunction":
595-
"""Get the statistics attached function for this collection.
596-
597-
Returns:
598-
AttachedFunction: The statistics function
599-
600-
Raises:
601-
NotFoundError: If statistics are not enabled
602-
AssertionError: If the attached function is not a statistics function
603-
"""
604-
af = self.get_attached_function(self._get_statistics_fn_name())
605-
assert af.function_name == "statistics", "Attached function is not a statistics function"
606-
return af
607-
608-
def disable_statistics(self, delete_stats_collection: bool = False) -> bool:
609-
"""Disable statistics collection for this collection.
610-
611-
Args:
612-
delete_stats_collection: If True, also delete the statistics output collection.
613-
Defaults to False.
614-
615-
Returns:
616-
bool: True if successful
617-
618-
Example:
619-
>>> collection.disable_statistics(delete_stats_collection=True)
620-
"""
621-
return self._get_statistics_fn().detach(delete_output_collection=delete_stats_collection)
622-
623-
def statistics(self) -> Dict[str, Any]:
624-
"""Get the current statistics for this collection.
625-
626-
Statistics include frequency counts for all metadata key-value pairs,
627-
as well as a summary with the total record count.
628-
629-
Returns:
630-
Dict[str, Any]: A dictionary with the structure:
631-
{
632-
"statistics": {
633-
"key1": {
634-
"value1": {"type": count, ...},
635-
"value2": {"type": count, ...}
636-
},
637-
"key2": {...},
638-
...
639-
},
640-
"summary": {
641-
"total_count": count
642-
}
643-
}
644-
645-
Example:
646-
>>> collection.enable_statistics()
647-
>>> collection.add(
648-
... ids=["id1", "id2"],
649-
... documents=["doc1", "doc2"],
650-
... metadatas=[{"category": "A", "score": 10}, {"category": "B", "score": 10}]
651-
... )
652-
>>> # Wait for statistics to be computed
653-
>>> stats = collection.statistics()
654-
>>> print(stats)
655-
{
656-
"statistics": {
657-
"category": {
658-
"A": {"str": 1},
659-
"B": {"str": 1}
660-
},
661-
"score": {
662-
"10": {"int": 2}
663-
}
664-
},
665-
"summary": {
666-
"total_count": 2
667-
}
668-
}
669-
"""
670-
from collections import defaultdict
671-
672-
af = self._get_statistics_fn()
673-
674-
# Get the statistics output collection model from the server
675-
stats_collection_model = self._client.get_collection(
676-
name=af.output_collection,
677-
tenant=self.tenant,
678-
database=self.database,
679-
)
680-
681-
# Wrap it in a Collection object to access get/query methods
682-
stats_collection = Collection(
683-
client=self._client,
684-
model=stats_collection_model,
685-
embedding_function=None, # Statistics collections don't need embedding functions
686-
data_loader=None,
687-
)
688-
689-
# Get all statistics records
690-
results = stats_collection.get(include=["metadatas"])
691-
692-
stats: Dict[str, Dict[str, Dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
693-
summary: Dict[str, Any] = {}
694-
695-
if results["metadatas"]:
696-
for metadata in results["metadatas"]:
697-
if metadata is None:
698-
continue
699-
700-
key = metadata.get("key")
701-
value = metadata.get("value")
702-
value_label = metadata.get("value_label")
703-
value_type = metadata.get("type")
704-
count = metadata.get("count")
705-
706-
if key is not None and value is not None and value_type is not None and count is not None:
707-
# Separate summary statistics from regular statistics
708-
if key == "summary":
709-
if value == "total_count":
710-
summary["total_count"] = count
711-
else:
712-
# Prioritize value_label if present, otherwise use value
713-
stats_key = value_label if value_label is not None else value
714-
stats[key][stats_key]['count'] = count
715-
716-
result = {"statistics": dict(stats)}
717-
if summary:
718-
result["summary"] = summary
719-
720-
return result

0 commit comments

Comments
 (0)