Skip to content

Commit c8ad86a

Browse files
authored
Merge pull request #27 from VowpalWabbit/cleanup
Rename scorer keys, input validation, escape newlines/tabs/etc
2 parents a45671c + a62686f commit c8ad86a

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

src/learn_to_pick/base.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,9 @@ def get_default_system_prompt() -> str:
259259

260260
@staticmethod
261261
def get_default_prompt() -> str:
262-
human_template = """Given this based_on "{pick_best_selected_based_on}" \
262+
human_template = """Given this based_on "{selected_based_on}" \
263263
as the most important attribute, rank how good or bad this text is: \
264-
"{pick_best_selected}"."""
264+
"{picked}"."""
265265
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
266266
return default_system_prompt + human_template
267267

@@ -325,8 +325,8 @@ class RLLoop(Generic[TEvent]):
325325
"""
326326

327327
# Define the default values as class attributes
328-
selected_input_key = "pick_best_selected"
329-
selected_based_on_input_key = "pick_best_selected_based_on"
328+
selected_based_on_input_key = "selected_based_on"
329+
selected_input_key = "picked"
330330

331331
def __init__(
332332
self,
@@ -435,19 +435,29 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
435435
"Either a dictionary positional argument or keyword arguments should be provided"
436436
)
437437

438+
if self.selected_based_on_input_key in inputs:
439+
raise ValueError(
440+
f"The input key {self.selected_based_on_input_key} is reserved. Please use a different key."
441+
)
442+
443+
if self.selected_input_key in inputs:
444+
raise ValueError(
445+
f"The input key {self.selected_input_key} is reserved. Please use a different key."
446+
)
447+
438448
event: TEvent = self._call_before_predict(inputs=inputs)
439449
prediction = self.policy.predict(event=event)
440450
if self.metrics:
441451
self.metrics.on_decision()
442452

443-
next_chain_inputs, picked, event = self._call_after_predict_before_scoring(
453+
next_inputs, picked, event = self._call_after_predict_before_scoring(
444454
inputs=inputs, event=event, prediction=prediction
445455
)
446456

447457
for callback_func in self.callbacks_before_scoring:
448458
try:
449-
next_chain_inputs, event = callback_func(
450-
inputs=next_chain_inputs, picked=picked, event=event
459+
next_inputs, event = callback_func(
460+
inputs=next_inputs, picked=picked, event=event
451461
)
452462
except Exception as e:
453463
logger.info(f"Callback function {callback_func} failed, error: {e}")
@@ -456,7 +466,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
456466
try:
457467
if self._can_use_selection_scorer():
458468
score = self.selection_scorer.score_response(
459-
inputs=next_chain_inputs, picked=picked, event=event
469+
inputs=next_inputs, picked=picked, event=event
460470
)
461471
except Exception as e:
462472
logger.info(
@@ -471,21 +481,26 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
471481
self.policy.learn(event=event)
472482
self.policy.log(event=event)
473483

474-
event.outputs = next_chain_inputs
484+
event.outputs = next_inputs
475485
return {"picked": picked, "picked_metadata": event}
476486

477487

478488
def _embed_string_type(
479489
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
480490
) -> Dict[str, Union[str, List[str]]]:
481491
"""Helper function to embed a string or an _Embed object."""
492+
import re
493+
482494
keep_str = ""
483495
if isinstance(item, _Embed):
484496
encoded = _stringify_embedding(model.encode(item.value))
497+
# TODO these should be moved to pick_best
485498
if item.keep:
486499
keep_str = item.value.replace(" ", "_") + " "
500+
keep_str = re.sub(r"[\t\n\r\f\v]+", " ", keep_str)
487501
elif isinstance(item, str):
488502
encoded = item.replace(" ", "_")
503+
encoded = re.sub(r"[\t\n\r\f\v]+", " ", encoded)
489504
else:
490505
raise ValueError(f"Unsupported type {type(item)} for embedding")
491506

src/learn_to_pick/pick_best.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,7 @@ def _call_after_predict_before_scoring(
297297
selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
298298
event.selected = selected
299299

300-
# only one key, value pair in event.to_select_from
301-
key, value = next(iter(event.to_select_from.items()))
302300
next_inputs = inputs.copy()
303-
next_inputs[key] = value[event.selected.index]
304301

305302
# only one key, value pair in event.to_select_from
306303
value = next(iter(event.to_select_from.values()))
@@ -309,12 +306,14 @@ def _call_after_predict_before_scoring(
309306
if event.selected
310307
else event.to_select_from.values()
311308
)
312-
next_inputs[self.selected_based_on_input_key] = str(event.based_on)
313-
next_inputs[self.selected_input_key] = v
309+
314310
picked = {}
315311
for k, v in event.to_select_from.items():
316312
picked[k] = v[event.selected.index]
317313

314+
next_inputs[self.selected_based_on_input_key] = str(event.based_on)
315+
next_inputs[self.selected_input_key] = str(picked)
316+
318317
return next_inputs, picked, event
319318

320319
def _call_after_scoring_before_learning(

0 commit comments

Comments
 (0)