@@ -259,9 +259,9 @@ def get_default_system_prompt() -> str:
259259
260260 @staticmethod
261261 def get_default_prompt () -> str :
262- human_template = """Given this based_on "{pick_best_selected_based_on }" \
262+ human_template = """Given this based_on "{selected_based_on }" \
263263 as the most important attribute, rank how good or bad this text is: \
264- "{pick_best_selected }"."""
264+ "{picked }"."""
265265 default_system_prompt = AutoSelectionScorer .get_default_system_prompt ()
266266 return default_system_prompt + human_template
267267
@@ -325,8 +325,8 @@ class RLLoop(Generic[TEvent]):
325325 """
326326
327327 # Define the default values as class attributes
328- selected_input_key = "pick_best_selected "
329- selected_based_on_input_key = "pick_best_selected_based_on "
328+ selected_based_on_input_key = "selected_based_on "
329+ selected_input_key = "picked "
330330
331331 def __init__ (
332332 self ,
@@ -435,19 +435,29 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
435435 "Either a dictionary positional argument or keyword arguments should be provided"
436436 )
437437
438+ if self .selected_based_on_input_key in inputs :
439+ raise ValueError (
440+ f"The input key { self .selected_based_on_input_key } is reserved. Please use a different key."
441+ )
442+
443+ if self .selected_input_key in inputs :
444+ raise ValueError (
445+ f"The input key { self .selected_input_key } is reserved. Please use a different key."
446+ )
447+
438448 event : TEvent = self ._call_before_predict (inputs = inputs )
439449 prediction = self .policy .predict (event = event )
440450 if self .metrics :
441451 self .metrics .on_decision ()
442452
443- next_chain_inputs , picked , event = self ._call_after_predict_before_scoring (
453+ next_inputs , picked , event = self ._call_after_predict_before_scoring (
444454 inputs = inputs , event = event , prediction = prediction
445455 )
446456
447457 for callback_func in self .callbacks_before_scoring :
448458 try :
449- next_chain_inputs , event = callback_func (
450- inputs = next_chain_inputs , picked = picked , event = event
459+ next_inputs , event = callback_func (
460+ inputs = next_inputs , picked = picked , event = event
451461 )
452462 except Exception as e :
453463 logger .info (f"Callback function { callback_func } failed, error: { e } " )
@@ -456,7 +466,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
456466 try :
457467 if self ._can_use_selection_scorer ():
458468 score = self .selection_scorer .score_response (
459- inputs = next_chain_inputs , picked = picked , event = event
469+ inputs = next_inputs , picked = picked , event = event
460470 )
461471 except Exception as e :
462472 logger .info (
@@ -471,21 +481,26 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
471481 self .policy .learn (event = event )
472482 self .policy .log (event = event )
473483
474- event .outputs = next_chain_inputs
484+ event .outputs = next_inputs
475485 return {"picked" : picked , "picked_metadata" : event }
476486
477487
478488def _embed_string_type (
479489 item : Union [str , _Embed ], model : Any , namespace : Optional [str ] = None
480490) -> Dict [str , Union [str , List [str ]]]:
481491 """Helper function to embed a string or an _Embed object."""
492+ import re
493+
482494 keep_str = ""
483495 if isinstance (item , _Embed ):
484496 encoded = _stringify_embedding (model .encode (item .value ))
497+ # TODO these should be moved to pick_best
485498 if item .keep :
486499 keep_str = item .value .replace (" " , "_" ) + " "
500+ keep_str = re .sub (r"[\t\n\r\f\v]+" , " " , keep_str )
487501 elif isinstance (item , str ):
488502 encoded = item .replace (" " , "_" )
503+ encoded = re .sub (r"[\t\n\r\f\v]+" , " " , encoded )
489504 else :
490505 raise ValueError (f"Unsupported type { type (item )} for embedding" )
491506
0 commit comments