Skip to content

Commit 575fa81

Browse files
authored
Merge pull request #19 from VowpalWabbit/add_black_formatting_check
add black formatting ci check
2 parents b732272 + 873ec6d commit 575fa81

File tree

9 files changed: 80 additions and 97 deletions.
Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,25 @@
1+
name: Black Formatting Check
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
branches:
9+
- '*'
10+
11+
jobs:
12+
black-check:
13+
container:
14+
image: python:3.8
15+
runs-on: ubuntu-latest
16+
17+
steps:
18+
- uses: actions/checkout@v2
19+
- name: Install dependencies
20+
run: |
21+
pip install .[dev]
22+
23+
- name: Check Black formatting
24+
run: |
25+
black --check .

.github/workflows/unit_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ jobs:
1414
image: python:3.8
1515
runs-on: ubuntu-latest
1616
steps:
17-
- uses: actions/checkout@v1
17+
- uses: actions/checkout@v2
1818
- name: Run Tests
1919
shell: bash
2020
run: |

setup.py

Lines changed: 9 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,23 +4,19 @@
44
name="learn_to_pick",
55
version="0.1",
66
install_requires=[
7-
'numpy',
8-
'pandas',
9-
'vowpal-wabbit-next',
10-
'sentence-transformers',
11-
'torch',
12-
'pyskiplist',
13-
'parameterfree',
7+
"numpy",
8+
"pandas",
9+
"vowpal-wabbit-next",
10+
"sentence-transformers",
11+
"torch",
12+
"pyskiplist",
13+
"parameterfree",
1414
],
15-
extras_require={
16-
'dev': [
17-
'pytest'
18-
]
19-
},
15+
extras_require={"dev": ["pytest", "black==23.10.0"]},
2016
author="VowpalWabbit",
2117
description="",
2218
packages=find_packages(where="src"),
2319
package_dir={"": "src"},
2420
url="https://github.com/VowpalWabbit/learn_to_pick",
25-
python_requires='>=3.8',
21+
python_requires=">=3.8",
2622
)

src/learn_to_pick/base.py

Lines changed: 24 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,7 @@
1515
Union,
1616
)
1717

18-
from learn_to_pick.metrics import (
19-
MetricsTrackerAverage,
20-
MetricsTrackerRollingWindow,
21-
)
18+
from learn_to_pick.metrics import MetricsTrackerAverage, MetricsTrackerRollingWindow
2219
from learn_to_pick.model_repository import ModelRepository
2320
from learn_to_pick.vw_logger import VwLogger
2421

@@ -234,14 +231,12 @@ class SelectionScorer(Generic[TEvent], ABC):
234231
"""
235232

236233
@abstractmethod
237-
def score_response(
238-
self, inputs: Dict[str, Any], picked: Any, event: TEvent
239-
) -> Any:
234+
def score_response(self, inputs: Dict[str, Any], picked: Any, event: TEvent) -> Any:
240235
"""
241236
Calculate and return the score for the selected response.
242237
243238
This is an abstract method and should be implemented by subclasses.
244-
The method defines a blueprint for applying scoring logic based on the provided
239+
The method defines a blueprint for applying scoring logic based on the provided
245240
inputs, the selection made by the policy, and additional metadata from the event.
246241
247242
Args:
@@ -256,10 +251,12 @@ def score_response(
256251

257252

258253
class AutoSelectionScorer(SelectionScorer[Event]):
259-
def __init__(self,
260-
llm,
261-
prompt: Union[Any, None] = None,
262-
scoring_criteria_template_str: Optional[str] = None):
254+
def __init__(
255+
self,
256+
llm,
257+
prompt: Union[Any, None] = None,
258+
scoring_criteria_template_str: Optional[str] = None,
259+
):
263260
self.llm = llm
264261
self.prompt = prompt
265262
if prompt is None and scoring_criteria_template_str is None:
@@ -285,16 +282,19 @@ def get_default_prompt() -> str:
285282
@staticmethod
286283
def format_with_ignoring_extra_args(prompt, inputs):
287284
import string
285+
288286
# Extract placeholders from the prompt
289-
placeholders = [field[1] for field in string.Formatter().parse(str(prompt)) if field[1]]
287+
placeholders = [
288+
field[1] for field in string.Formatter().parse(str(prompt)) if field[1]
289+
]
290290

291291
# Keep only the inputs that have corresponding placeholders in the prompt
292292
relevant_inputs = {k: v for k, v in inputs.items() if k in placeholders}
293293

294294
return prompt.format(**relevant_inputs)
295295

296296
def score_response(
297-
self, inputs: Dict[str, Any], picked: Any, event: Event
297+
self, inputs: Dict[str, Any], picked: Any, event: Event
298298
) -> float:
299299
p = AutoSelectionScorer.format_with_ignoring_extra_args(self.prompt, inputs)
300300
ranking = self.llm.predict(p)
@@ -337,6 +337,7 @@ class RLLoop(Generic[TEvent]):
337337
Notes:
338338
By default the class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
339339
"""
340+
340341
# Define the default values as class attributes
341342
selected_input_key = "pick_best_selected"
342343
selected_based_on_input_key = "pick_best_selected_based_on"
@@ -409,7 +410,6 @@ def save_progress(self) -> None:
409410
"""
410411
self.policy.save()
411412

412-
413413
def _can_use_selection_scorer(self) -> bool:
414414
"""
415415
Returns whether the chain can use the selection scorer to score responses or not.
@@ -422,10 +422,7 @@ def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
422422

423423
@abstractmethod
424424
def _call_after_predict_before_scoring(
425-
self,
426-
inputs: Dict[str, Any],
427-
event: Event,
428-
prediction: List[Tuple[int, float]],
425+
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
429426
) -> Tuple[Dict[str, Any], Event]:
430427
...
431428

@@ -448,7 +445,9 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
448445
elif kwargs and not args:
449446
inputs = kwargs
450447
else:
451-
raise ValueError("Either a dictionary positional argument or keyword arguments should be provided")
448+
raise ValueError(
449+
"Either a dictionary positional argument or keyword arguments should be provided"
450+
)
452451

453452
event: TEvent = self._call_before_predict(inputs=inputs)
454453
prediction = self.policy.predict(event=event)
@@ -461,11 +460,11 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
461460

462461
for callback_func in self.callbacks_before_scoring:
463462
try:
464-
next_chain_inputs, event = callback_func(inputs=next_chain_inputs, picked=picked, event=event)
465-
except Exception as e:
466-
logger.info(
467-
f"Callback function {callback_func} failed, error: {e}"
463+
next_chain_inputs, event = callback_func(
464+
inputs=next_chain_inputs, picked=picked, event=event
468465
)
466+
except Exception as e:
467+
logger.info(f"Callback function {callback_func} failed, error: {e}")
469468

470469
score = None
471470
try:
@@ -489,6 +488,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
489488
event.outputs = next_chain_inputs
490489
return {"picked": picked, "picked_metadata": event}
491490

491+
492492
def is_stringtype_instance(item: Any) -> bool:
493493
"""Helper function to check if an item is a string."""
494494
return isinstance(item, str) or (

src/learn_to_pick/metrics.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,9 @@ def on_feedback(self, value: float) -> None:
5858
self.sum -= old_val
5959

6060
if self.step > 0 and self.feedback_count % self.step == 0:
61-
self.history.append({"step": self.feedback_count, "score": self.sum / len(self.queue)})
61+
self.history.append(
62+
{"step": self.feedback_count, "score": self.sum / len(self.queue)}
63+
)
6264

6365
def to_pandas(self) -> "pd.DataFrame":
6466
import pandas as pd

src/learn_to_pick/pick_best.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -352,7 +352,7 @@ def create(
352352
logger.warning(
353353
f"{[k for k, v in policy_args.items() if v]} will be ignored since nontrivial policy is provided, please set those arguments in the policy directly if needed"
354354
)
355-
355+
356356
if policy_args["model_save_dir"] is None:
357357
policy_args["model_save_dir"] = "./"
358358
if policy_args["reset_model"] is None:
@@ -370,7 +370,7 @@ def create_policy(
370370
vw_cmd: Optional[List[str]] = None,
371371
model_save_dir: str = "./",
372372
reset_model: bool = False,
373-
rl_logs: Optional[Union[str, os.PathLike]] = None
373+
rl_logs: Optional[Union[str, os.PathLike]] = None,
374374
):
375375
if not featurizer:
376376
featurizer = PickBestFeaturizer(auto_embed=False)
@@ -384,20 +384,15 @@ def create_policy(
384384
)
385385
else:
386386
interactions += ["--interactions=::"]
387-
vw_cmd = [
388-
"--cb_explore_adf",
389-
"--coin",
390-
"--squarecb",
391-
"--quiet",
392-
]
387+
vw_cmd = ["--cb_explore_adf", "--coin", "--squarecb", "--quiet"]
393388

394389
if featurizer.auto_embed:
395390
interactions += [
396391
"--interactions=@#",
397392
"--ignore_linear=@",
398393
"--ignore_linear=#",
399394
]
400-
395+
401396
vw_cmd = interactions + vw_cmd
402397

403398
return base.VwPolicy(

tests/unit_tests/test_pick_best_call.py

Lines changed: 8 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -155,12 +155,8 @@ def score_response(
155155

156156

157157
def test_everything_embedded() -> None:
158-
featurizer = learn_to_pick.PickBestFeaturizer(
159-
auto_embed=False, model=MockEncoder()
160-
)
161-
pick = learn_to_pick.PickBest.create(
162-
llm=fake_llm_caller, featurizer=featurizer
163-
)
158+
featurizer = learn_to_pick.PickBestFeaturizer(auto_embed=False, model=MockEncoder())
159+
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)
164160

165161
str1 = "0"
166162
str2 = "1"
@@ -187,12 +183,8 @@ def test_everything_embedded() -> None:
187183

188184

189185
def test_default_auto_embedder_is_off() -> None:
190-
featurizer = learn_to_pick.PickBestFeaturizer(
191-
auto_embed=False, model=MockEncoder()
192-
)
193-
pick = learn_to_pick.PickBest.create(
194-
llm=fake_llm_caller, featurizer=featurizer
195-
)
186+
featurizer = learn_to_pick.PickBestFeaturizer(auto_embed=False, model=MockEncoder())
187+
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)
196188

197189
str1 = "0"
198190
str2 = "1"
@@ -213,12 +205,8 @@ def test_default_auto_embedder_is_off() -> None:
213205

214206

215207
def test_default_w_embeddings_off() -> None:
216-
featurizer = learn_to_pick.PickBestFeaturizer(
217-
auto_embed=False, model=MockEncoder()
218-
)
219-
pick = learn_to_pick.PickBest.create(
220-
llm=fake_llm_caller, featurizer=featurizer
221-
)
208+
featurizer = learn_to_pick.PickBestFeaturizer(auto_embed=False, model=MockEncoder())
209+
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)
222210

223211
str1 = "0"
224212
str2 = "1"
@@ -242,9 +230,7 @@ def test_default_w_embeddings_on() -> None:
242230
featurizer = learn_to_pick.PickBestFeaturizer(
243231
auto_embed=True, model=MockEncoderReturnsList()
244232
)
245-
pick = learn_to_pick.PickBest.create(
246-
llm=fake_llm_caller, featurizer=featurizer
247-
)
233+
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)
248234

249235
str1 = "0"
250236
str2 = "1"
@@ -268,9 +254,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
268254
featurizer = learn_to_pick.PickBestFeaturizer(
269255
auto_embed=True, model=MockEncoderReturnsList()
270256
)
271-
pick = learn_to_pick.PickBest.create(
272-
llm=fake_llm_caller, featurizer=featurizer
273-
)
257+
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)
274258

275259
str1 = "0"
276260
str2 = "1"

tests/unit_tests/test_pick_best_text_embedder.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -257,11 +257,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N
257257
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
258258

259259
named_actions = {
260-
"action1": [
261-
{"a": str1, "b": rl_chain.Embed(str1)},
262-
str2,
263-
rl_chain.Embed(str3),
264-
]
260+
"action1": [{"a": str1, "b": rl_chain.Embed(str1)}, str2, rl_chain.Embed(str3)]
265261
}
266262
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
267263
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
@@ -296,10 +292,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
296292
rl_chain.EmbedAndKeep(str3),
297293
]
298294
}
299-
context = {
300-
"context1": ctx_str_1,
301-
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
302-
}
295+
context = {"context1": ctx_str_1, "context2": rl_chain.EmbedAndKeep(ctx_str_2)}
303296
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
304297

305298
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)

tests/unit_tests/test_rl_loop_base_embedder.py

Lines changed: 4 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -83,10 +83,7 @@ def test_context_w_namespace_w_some_emb() -> None:
8383
== expected
8484
)
8585
expected_embed_and_keep = [
86-
{
87-
"test_namespace": str1,
88-
"test_namespace2": str2 + " " + encoded_str2,
89-
}
86+
{"test_namespace": str1, "test_namespace2": str2 + " " + encoded_str2}
9087
]
9188
assert (
9289
base.embed(
@@ -337,18 +334,9 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
337334
== expected
338335
)
339336
expected_embed_and_keep = [
340-
{
341-
"test_namespace": str1 + " " + encoded_str1,
342-
"test_namespace2": str1,
343-
},
344-
{
345-
"test_namespace": str2 + " " + encoded_str2,
346-
"test_namespace2": str2,
347-
},
348-
{
349-
"test_namespace": str3 + " " + encoded_str3,
350-
"test_namespace2": str3,
351-
},
337+
{"test_namespace": str1 + " " + encoded_str1, "test_namespace2": str1},
338+
{"test_namespace": str2 + " " + encoded_str2, "test_namespace2": str2},
339+
{"test_namespace": str3 + " " + encoded_str3, "test_namespace2": str3},
352340
]
353341
assert (
354342
base.embed(

0 commit comments

Comments
 (0)