1- #!/usr/bin/python
2- # coding: utf-8
3- from agent_utilities import get_logger
4- from .base import QueryResults
5- from typing import Any
6-
7- logger = get_logger (__name__ )
8-
91#!/usr/bin/python
102# coding: utf-8
113import inspect
124import re
13- import sys
14- import logging
15- from typing import Any
165from abc import ABC , abstractmethod
176from collections .abc import Callable , Generator , Iterable
187from contextlib import contextmanager , suppress
198from dataclasses import dataclass
209from functools import wraps
21- from logging import getLogger
22- from typing import Generic , Optional , TypeVar
23- from packaging import version
2410from typing import (
25- TYPE_CHECKING ,
26- )
27- from typing_extensions import (
28- ParamSpec ,
11+ Any ,
12+ Generic ,
13+ Optional ,
14+ TypeVar ,
2915)
16+
17+ from packaging import version
18+
19+ from agent_utilities import get_logger
20+
3021from .base import QueryResults
3122
23+ logger = get_logger (__name__ )
24+
3225__all__ = [
3326 "optional_import_block" ,
3427 "require_optional_import" ,
28+ "filter_results_by_distance" ,
29+ "chroma_results_to_query_results" ,
3530]
3631
3732
38- logger = getLogger (__name__ )
39-
40-
41- def get_logger (name : str ):
42- logger = getLogger (name )
43- logger .setLevel (logging .DEBUG )
44- logger .handlers .clear ()
45- handler = logging .StreamHandler (sys .stdout )
46- handler .setLevel (logging .DEBUG )
47- formatter = logging .Formatter (
48- "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
49- )
50- handler .setFormatter (formatter )
51- logger .addHandler (handler )
52- return logger
53-
54-
5533def filter_results_by_distance (
56- results : QueryResults , distance_threshold : float = - 1
34+ results : QueryResults , distance_threshold : float = - 1
5735) -> QueryResults :
5836 """Filters results based on a distance threshold.
5937
@@ -74,7 +52,7 @@ def filter_results_by_distance(
7452
7553
7654def chroma_results_to_query_results (
77- data_dict : dict [str , list [list [Any ]]], special_key = "distances"
55+ data_dict : dict [str , list [list [Any ]]], special_key = "distances"
7856) -> QueryResults :
7957 """Converts a dictionary with list-of-list values to a list of tuples.
8058
@@ -121,8 +99,8 @@ def chroma_results_to_query_results(
12199 key
122100 for key in data_dict
123101 if key != special_key
124- and data_dict [key ] is not None
125- and isinstance (data_dict [key ][0 ], list )
102+ and data_dict [key ] is not None
103+ and isinstance (data_dict [key ][0 ], list )
126104 ]
127105 result = []
128106 data_special_key = data_dict [special_key ]
@@ -378,11 +356,11 @@ def decorator(subclass: type["PatchObject[Any]"]) -> type["PatchObject[Any]"]:
378356
379357 @classmethod
380358 def create (
381- cls ,
382- o : T ,
383- * ,
384- missing_modules : dict [str , str ],
385- dep_target : str ,
359+ cls ,
360+ o : T ,
361+ * ,
362+ missing_modules : dict [str , str ],
363+ dep_target : str ,
386364 ) -> Optional ["PatchObject[T]" ]:
387365 for subclass in cls ._registry :
388366 if subclass .accept (o ):
@@ -519,12 +497,12 @@ def patch(self, except_for: Iterable[str]) -> type[Any]:
519497
520498
521499def patch_object (
522- o : T ,
523- * ,
524- missing_modules : dict [str , str ],
525- dep_target : str ,
526- fail_if_not_patchable : bool = True ,
527- except_for : str | Iterable [str ] | None = None ,
500+ o : T ,
501+ * ,
502+ missing_modules : dict [str , str ],
503+ dep_target : str ,
504+ fail_if_not_patchable : bool = True ,
505+ except_for : str | Iterable [str ] | None = None ,
528506) -> T :
529507 patcher = PatchObject .create (
530508 o , missing_modules = missing_modules , dep_target = dep_target
@@ -539,10 +517,10 @@ def patch_object(
539517
540518
541519def require_optional_import (
542- modules : str | Iterable [str ],
543- dep_target : str ,
544- * ,
545- except_for : str | Iterable [str ] | None = None ,
520+ modules : str | Iterable [str ],
521+ dep_target : str ,
522+ * ,
523+ except_for : str | Iterable [str ] | None = None ,
546524) -> Callable [[T ], T ]:
547525 """Decorator to handle optional module dependencies
548526
@@ -569,58 +547,3 @@ def decorator(o: T) -> T:
569547 )
570548
571549 return decorator
572-
573-
574- if TYPE_CHECKING :
575- pass
576-
577- P = ParamSpec ("P" )
578- T = TypeVar ("T" )
579-
580-
581- def filter_results_by_distance (
582- results : QueryResults , distance_threshold : float = - 1
583- ) -> QueryResults :
584- """Filters results based on a distance threshold.
585-
586- Args:
587- results: QueryResults | The query results. List[List[Tuple[Document, float]]]
588- distance_threshold: The maximum distance allowed for results.
589-
590- Returns:
591- QueryResults | A filtered results containing only distances smaller than the threshold.
592- """
593- if distance_threshold > 0 :
594- results = [
595- [(key , value ) for key , value in data if value < distance_threshold ]
596- for data in results
597- ]
598-
599- return results
600-
601-
602- def chroma_results_to_query_results (
603- data_dict : dict [str , list [list [Any ]]], special_key = "distances"
604- ) -> QueryResults :
605- """Converts a dictionary with list-of-list values to a list of tuples."""
606- keys = [
607- key
608- for key in data_dict
609- if key != special_key
610- and data_dict [key ] is not None
611- and isinstance (data_dict [key ][0 ], list )
612- ]
613- result = []
614- data_special_key = data_dict [special_key ]
615-
616- for i in range (len (data_special_key )):
617- sub_result = []
618- for j , distance in enumerate (data_special_key [i ]):
619- sub_dict = {}
620- for key in keys :
621- if len (data_dict [key ]) > i :
622- sub_dict [key [:- 1 ]] = data_dict [key ][i ][j ]
623- sub_result .append ((sub_dict , distance ))
624- result .append (sub_result )
625-
626- return result
0 commit comments