1515 Union ,
1616)
1717
18- from learn_to_pick .metrics import (
19- MetricsTrackerAverage ,
20- MetricsTrackerRollingWindow ,
21- )
18+ from learn_to_pick .metrics import MetricsTrackerAverage , MetricsTrackerRollingWindow
2219from learn_to_pick .model_repository import ModelRepository
2320from learn_to_pick .vw_logger import VwLogger
2421
@@ -234,14 +231,12 @@ class SelectionScorer(Generic[TEvent], ABC):
234231 """
235232
236233 @abstractmethod
237- def score_response (
238- self , inputs : Dict [str , Any ], picked : Any , event : TEvent
239- ) -> Any :
234+ def score_response (self , inputs : Dict [str , Any ], picked : Any , event : TEvent ) -> Any :
240235 """
241236 Calculate and return the score for the selected response.
242237
243238 This is an abstract method and should be implemented by subclasses.
244- The method defines a blueprint for applying scoring logic based on the provided
239+ The method defines a blueprint for applying scoring logic based on the provided
245240 inputs, the selection made by the policy, and additional metadata from the event.
246241
247242 Args:
@@ -256,10 +251,12 @@ def score_response(
256251
257252
258253class AutoSelectionScorer (SelectionScorer [Event ]):
259- def __init__ (self ,
260- llm ,
261- prompt : Union [Any , None ] = None ,
262- scoring_criteria_template_str : Optional [str ] = None ):
254+ def __init__ (
255+ self ,
256+ llm ,
257+ prompt : Union [Any , None ] = None ,
258+ scoring_criteria_template_str : Optional [str ] = None ,
259+ ):
263260 self .llm = llm
264261 self .prompt = prompt
265262 if prompt is None and scoring_criteria_template_str is None :
@@ -285,16 +282,19 @@ def get_default_prompt() -> str:
285282 @staticmethod
286283 def format_with_ignoring_extra_args (prompt , inputs ):
287284 import string
285+
288286 # Extract placeholders from the prompt
289- placeholders = [field [1 ] for field in string .Formatter ().parse (str (prompt )) if field [1 ]]
287+ placeholders = [
288+ field [1 ] for field in string .Formatter ().parse (str (prompt )) if field [1 ]
289+ ]
290290
291291 # Keep only the inputs that have corresponding placeholders in the prompt
292292 relevant_inputs = {k : v for k , v in inputs .items () if k in placeholders }
293293
294294 return prompt .format (** relevant_inputs )
295295
296296 def score_response (
297- self , inputs : Dict [str , Any ], picked : Any , event : Event
297+ self , inputs : Dict [str , Any ], picked : Any , event : Event
298298 ) -> float :
299299 p = AutoSelectionScorer .format_with_ignoring_extra_args (self .prompt , inputs )
300300 ranking = self .llm .predict (p )
@@ -337,6 +337,7 @@ class RLLoop(Generic[TEvent]):
337337 Notes:
338338 By default the class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
339339 """
340+
340341 # Define the default values as class attributes
341342 selected_input_key = "pick_best_selected"
342343 selected_based_on_input_key = "pick_best_selected_based_on"
@@ -409,7 +410,6 @@ def save_progress(self) -> None:
409410 """
410411 self .policy .save ()
411412
412-
413413 def _can_use_selection_scorer (self ) -> bool :
414414 """
415415 Returns whether the chain can use the selection scorer to score responses or not.
@@ -422,10 +422,7 @@ def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
422422
423423 @abstractmethod
424424 def _call_after_predict_before_scoring (
425- self ,
426- inputs : Dict [str , Any ],
427- event : Event ,
428- prediction : List [Tuple [int , float ]],
425+ self , inputs : Dict [str , Any ], event : Event , prediction : List [Tuple [int , float ]]
429426 ) -> Tuple [Dict [str , Any ], Event ]:
430427 ...
431428
@@ -448,7 +445,9 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
448445 elif kwargs and not args :
449446 inputs = kwargs
450447 else :
451- raise ValueError ("Either a dictionary positional argument or keyword arguments should be provided" )
448+ raise ValueError (
449+ "Either a dictionary positional argument or keyword arguments should be provided"
450+ )
452451
453452 event : TEvent = self ._call_before_predict (inputs = inputs )
454453 prediction = self .policy .predict (event = event )
@@ -461,11 +460,11 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
461460
462461 for callback_func in self .callbacks_before_scoring :
463462 try :
464- next_chain_inputs , event = callback_func (inputs = next_chain_inputs , picked = picked , event = event )
465- except Exception as e :
466- logger .info (
467- f"Callback function { callback_func } failed, error: { e } "
463+ next_chain_inputs , event = callback_func (
464+ inputs = next_chain_inputs , picked = picked , event = event
468465 )
466+ except Exception as e :
467+ logger .info (f"Callback function { callback_func } failed, error: { e } " )
469468
470469 score = None
471470 try :
@@ -489,6 +488,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
489488 event .outputs = next_chain_inputs
490489 return {"picked" : picked , "picked_metadata" : event }
491490
491+
492492def is_stringtype_instance (item : Any ) -> bool :
493493 """Helper function to check if an item is a string."""
494494 return isinstance (item , str ) or (
0 commit comments