33
44import re
55
6- from typing import TYPE_CHECKING
6+ from collections .abc import Iterable
7+ from typing import TYPE_CHECKING , Any
78
89import requests
910
2324# before sending text to the reranker. This keeps inputs clean and
2425# avoids misleading the model with bracketed prefixes.
2526_TAG1 = re .compile (r"^\s*\[[^\]]*\]\s*" )
27+ DEFAULT_BOOST_WEIGHTS = {"user_id" : 0.5 , "tags" : 0.2 , "session_id" : 0.3 }
28+
29+
30+ def _value_matches (item_value : Any , wanted : Any ) -> bool :
31+ """
32+ Generic matching:
33+ - if item_value is list/tuple/set: check membership (any match if wanted is iterable)
34+ - else: equality (any match if wanted is iterable)
35+ """
36+
37+ def _iterable (x ):
38+ # exclude strings from "iterable"
39+ return isinstance (x , Iterable ) and not isinstance (x , str | bytes )
40+
41+ if _iterable (item_value ):
42+ if _iterable (wanted ):
43+ return any (w in item_value for w in wanted )
44+ return wanted in item_value
45+ else :
46+ if _iterable (wanted ):
47+ return any (item_value == w for w in wanted )
48+ return item_value == wanted
2649
2750
2851class HTTPBGEReranker (BaseReranker ):
@@ -58,6 +81,9 @@ def __init__(
5881 timeout : int = 10 ,
5982 headers_extra : dict | None = None ,
6083 rerank_source : list [str ] | None = None ,
84+ boost_weights : dict [str , float ] | None = None ,
85+ boost_default : float = 0.0 ,
86+ warn_unknown_filter_keys : bool = True ,
6187 ** kwargs ,
6288 ):
6389 """
@@ -83,6 +109,15 @@ def __init__(
83109 self .headers_extra = headers_extra or {}
84110 self .concat_source = rerank_source
85111
112+ self .boost_weights = (
113+ DEFAULT_BOOST_WEIGHTS .copy ()
114+ if boost_weights is None
115+ else {k : float (v ) for k , v in boost_weights .items ()}
116+ )
117+ self .boost_default = float (boost_default )
118+ self .warn_unknown_filter_keys = bool (warn_unknown_filter_keys )
119+ self ._warned_missing_keys : set [str ] = set ()
120+
86121 def rerank (
87122 self ,
88123 query : str ,
@@ -117,7 +152,6 @@ def rerank(
117152 # Build a mapping from "payload docs index" -> "original graph_results index"
118153 # Only include items that have a non-empty string memory. This ensures that
119154 # any index returned by the server can be mapped back correctly.
120- documents = []
121155 if self .concat_source :
122156 documents = concat_original_source (graph_results , self .concat_source )
123157 else :
@@ -155,8 +189,11 @@ def rerank(
155189 # The returned index refers to 'documents' (i.e., our 'pairs' order),
156190 # so we must map it back to the original graph_results index.
157191 if isinstance (idx , int ) and 0 <= idx < len (graph_results ):
158- score = float (r .get ("relevance_score" , r .get ("score" , 0.0 )))
159- scored_items .append ((graph_results [idx ], score ))
192+ raw_score = float (r .get ("relevance_score" , r .get ("score" , 0.0 )))
193+ item = graph_results [idx ]
194+ # generic boost
195+ score = self ._apply_boost_generic (item , raw_score , search_filter )
196+ scored_items .append ((item , score ))
160197
161198 scored_items .sort (key = lambda x : x [1 ], reverse = True )
162199 return scored_items [: min (top_k , len (scored_items ))]
@@ -172,8 +209,10 @@ def rerank(
172209 elif len (score_list ) > len (graph_results ):
173210 score_list = score_list [: len (graph_results )]
174211
175- # Map back to original items using 'pairs'
176- scored_items = list (zip (graph_results , score_list , strict = False ))
212+ scored_items = []
213+ for item , raw_score in zip (graph_results , score_list , strict = False ):
214+ score = self ._apply_boost_generic (item , raw_score , search_filter )
215+ scored_items .append ((item , score ))
177216 scored_items .sort (key = lambda x : x [1 ], reverse = True )
178217 return scored_items [: min (top_k , len (scored_items ))]
179218
@@ -187,3 +226,86 @@ def rerank(
187226 # Degrade gracefully by returning first top_k valid docs with 0.0 score.
188227 logger .error (f"[HTTPBGEReranker] request failed: { e } " )
189228 return [(item , 0.0 ) for item in graph_results [:top_k ]]
229+
230+ def _get_attr_or_key (self , obj : Any , key : str ) -> Any :
231+ """
232+ Resolve `key` on `obj` with one-level fallback into `obj.metadata`.
233+
234+ Priority:
235+ 1) obj.<key>
236+ 2) obj[key]
237+ 3) obj.metadata.<key>
238+ 4) obj.metadata[key]
239+ """
240+ if obj is None :
241+ return None
242+
243+ # support input like "metadata.user_id"
244+ if "." in key :
245+ head , tail = key .split ("." , 1 )
246+ base = self ._get_attr_or_key (obj , head )
247+ return self ._get_attr_or_key (base , tail )
248+
249+ def _resolve (o : Any , k : str ):
250+ if o is None :
251+ return None
252+ v = getattr (o , k , None )
253+ if v is not None :
254+ return v
255+ if hasattr (o , "get" ):
256+ try :
257+ return o .get (k )
258+ except Exception :
259+ return None
260+ return None
261+
262+ # 1) find in obj
263+ v = _resolve (obj , key )
264+ if v is not None :
265+ return v
266+
267+ # 2) find in obj.metadata
268+ meta = _resolve (obj , "metadata" )
269+ if meta is not None :
270+ return _resolve (meta , key )
271+
272+ return None
273+
274+ def _apply_boost_generic (
275+ self ,
276+ item : TextualMemoryItem ,
277+ base_score : float ,
278+ search_filter : dict | None ,
279+ ) -> float :
280+ """
281+ Multiply base_score by (1 + weight) for each matching key in search_filter.
282+ - key resolution: self._get_attr_or_key(item, key)
283+ - weight = boost_weights.get(key, self.boost_default)
284+ - unknown key -> one-time warning
285+ """
286+ if not search_filter :
287+ return base_score
288+
289+ score = float (base_score )
290+
291+ for key , wanted in search_filter .items ():
292+ # _get_attr_or_key automatically find key in item and
293+ # item.metadata ("metadata.user_id" supported)
294+ resolved = self ._get_attr_or_key (item , key )
295+
296+ if resolved is None :
297+ if self .warn_unknown_filter_keys and key not in self ._warned_missing_keys :
298+ logger .warning (
299+ "[HTTPBGEReranker] search_filter key '%s' not found on TextualMemoryItem or metadata" ,
300+ key ,
301+ )
302+ self ._warned_missing_keys .add (key )
303+ continue
304+
305+ if _value_matches (resolved , wanted ):
306+ w = float (self .boost_weights .get (key , self .boost_default ))
307+ if w != 0.0 :
308+ score *= 1.0 + w
309+ score = min (max (0.0 , score ), 1.0 )
310+
311+ return score
0 commit comments