Commit 6c7ad29

Keshav Ramji <Keshav.Ramji@ibm.com> authored and committed
Pre-commit fixes
1 parent c9bc4b5 commit 6c7ad29

File tree

4 files changed: +60 −34 lines changed

cli/eval/commands.py
cli/eval/runner.py
mellea/stdlib/reqlib/md.py
mellea/stdlib/test_based_eval.py

cli/eval/commands.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -9,19 +9,22 @@ def eval_run(
     ),
     backend: str = typer.Option("ollama", "--backend", "-b", help="Generation backend"),
     model: str = typer.Option(None, "--model", help="Generation model name"),
-    max_gen_tokens: int = typer.Option(256, "--max-gen-tokens", help="Max tokens to generate for responses"),
+    max_gen_tokens: int = typer.Option(
+        256, "--max-gen-tokens", help="Max tokens to generate for responses"
+    ),
     judge_backend: str = typer.Option(
         None, "--judge-backend", "-jb", help="Judge backend"
     ),
     judge_model: str = typer.Option(None, "--judge-model", help="Judge model name"),
-    max_judge_tokens: int = typer.Option(256, "--max-judge-tokens", help="Max tokens for the judge model's judgement."),
+    max_judge_tokens: int = typer.Option(
+        256, "--max-judge-tokens", help="Max tokens for the judge model's judgement."
+    ),
     output_path: str = typer.Option(
         "eval_results", "--output-path", "-o", help="Output path for results"
     ),
     output_format: str = typer.Option(
         "json", "--output-format", help="Either json or jsonl format for results"
     ),
-    verbose: bool = typer.Option(False, "--verbose", "-v"),
     continue_on_error: bool = typer.Option(True, "--continue-on-error"),
 ):
     from cli.eval.runner import run_evaluations
@@ -36,7 +39,6 @@ def eval_run(
         max_judge_tokens=max_judge_tokens,
         output_path=output_path,
         output_format=output_format,
-        verbose=verbose,
         continue_on_error=continue_on_error,
     )
```
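The reformatted `typer.Option` calls above are behavior-preserving; only the line wrapping changes. As a sanity check, a minimal self-contained sketch of the same pattern — the app object and command wiring here are hypothetical, only the two options are taken from the diff:

```python
import typer

app = typer.Typer()  # hypothetical standalone app, for illustration only


@app.command()
def eval_run(
    backend: str = typer.Option("ollama", "--backend", "-b", help="Generation backend"),
    max_gen_tokens: int = typer.Option(
        256, "--max-gen-tokens", help="Max tokens to generate for responses"
    ),
):
    # Typer derives flag names, defaults, and help text from the Option
    # metadata, so wrapping a call across lines has no runtime effect.
    typer.echo(f"backend={backend} max_gen_tokens={max_gen_tokens}")


if __name__ == "__main__":
    app()
```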

cli/eval/runner.py

Lines changed: 29 additions & 18 deletions
```diff
@@ -23,7 +23,7 @@ def __init__(
         model_output: str,
         validation_passed: bool,
         score: int,
-        validation_reason: str, # add input_id
+        validation_reason: str,  # add input_id
     ):
         self.input_text = input_text
         self.model_output = model_output
@@ -74,7 +74,9 @@ def pass_rate(self) -> float:
         return self.passed_count / self.total_count if self.total_count > 0 else 0.0
 
 
-def create_session(backend: str, model: str | None, max_tokens: int | None) -> mellea.MelleaSession:
+def create_session(
+    backend: str, model: str | None, max_tokens: int | None
+) -> mellea.MelleaSession:
     """Create a mellea session with the specified backend and model."""
 
     model_id = None
@@ -96,35 +98,40 @@ def create_session(backend: str, model: str | None, max_tokens: int | None) -> mellea.MelleaSession:
             from mellea.backends.ollama import OllamaModelBackend
 
             backend_instance = OllamaModelBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
             )
 
         elif backend_lower == "openai":
             from mellea.backends.openai import OpenAIBackend
 
             backend_instance = OpenAIBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
             )
 
         elif backend_lower in ["hf", "huggingface"]:
             from mellea.backends.huggingface import LocalHFBackend
 
             backend_instance = LocalHFBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
             )
 
         elif backend_lower == "watsonx":
             from mellea.backends.watsonx import WatsonxAIBackend
 
             backend_instance = WatsonxAIBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
            )
 
         elif backend_lower == "litellm":
             from mellea.backends.litellm import LiteLLMBackend
 
             backend_instance = LiteLLMBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
             )
 
         else:
@@ -135,9 +142,7 @@ def create_session(backend: str, model: str | None, max_tokens: int | None) -> mellea.MelleaSession:
         # create session with backend instance
         from mellea.stdlib.base import SimpleContext
 
-        session = mellea.MelleaSession(
-            backend=backend_instance, ctx=SimpleContext()
-        )
+        session = mellea.MelleaSession(backend=backend_instance, ctx=SimpleContext())
        return session
 
    except Exception as e:
```
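For context, a minimal sketch of the session construction the ollama branch of `create_session` performs, using only calls visible in this diff; the `ModelOption` import path and the `None` model-id fallback are assumptions:

```python
import mellea
from mellea.backends.ollama import OllamaModelBackend
from mellea.backends.types import ModelOption  # assumed import path
from mellea.stdlib.base import SimpleContext

# Mirrors the reformatted ollama branch above: one backend instance,
# then a session wrapping it with a fresh SimpleContext.
backend_instance = OllamaModelBackend(
    model_id=None,  # assumption: None selects the backend's default model
    model_options={ModelOption.MAX_NEW_TOKENS: 256},
)
session = mellea.MelleaSession(backend=backend_instance, ctx=SimpleContext())
print(str(session.act("Name three prime numbers.")))
```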
```diff
@@ -157,7 +162,6 @@ def run_evaluations(
     max_judge_tokens: int | None,
     output_path: str,
     output_format: str,
-    verbose: bool,
     continue_on_error: bool,
 ):
     """Run all 'unit test' evaluations"""
@@ -176,14 +180,16 @@
         return
 
     console.print(f"Total test evals to run: {len(all_test_evals)}")
-    total_inputs = sum(len(te.inputs) for te in all_test_evals)
+    total_inputs = sum(len(test_eval.inputs) for test_eval in all_test_evals)
     console.print(f"Total inputs to run: {total_inputs}")
 
     console.print(f"Generation model: {model}")
     console.print(f"Judge model: {judge_model}")
 
     m = create_session(backend=backend, model=model, max_tokens=max_gen_tokens)
-    judge_session = create_session(backend=judge_backend, model=judge_model, max_tokens=max_judge_tokens)
+    judge_session = create_session(
+        backend=judge_backend, model=judge_model, max_tokens=max_judge_tokens
+    )
 
     all_results = []
 
@@ -234,12 +240,14 @@ def execute_test_eval(
         result: ModelOutputThunk = generation_session.act(input_text)
         model_output = str(result)
 
-        judge_session.ctx = judge_session.ctx.add(result)
-
-        targets_for_input = (test_eval.targets[idx] if idx < len(test_eval.targets) else [])
+        targets_for_input = (
+            test_eval.targets[idx] if idx < len(test_eval.targets) else []
+        )
 
         # query the judge
-        judge_prompt = create_judge_requirement(test_eval, input_text, model_output, targets_for_input)
+        judge_prompt = create_judge_requirement(
+            test_eval, input_text, model_output, targets_for_input
+        )
         judge_output_thunk = judge_session.act(judge_prompt)
         judge_output = str(judge_output_thunk)
         score, justification = parse_judge_output(judge_output)
@@ -263,7 +271,10 @@
 
 
 def create_judge_requirement(
-    test_eval: TestBasedEval, input_text: str, model_output: str, targets_for_input: list[str]
+    test_eval: TestBasedEval,
+    input_text: str,
+    model_output: str,
+    targets_for_input: list[str],
 ):
     """Create judge requirement description"""
 
```
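Note that the removed `judge_session.ctx = judge_session.ctx.add(result)` line means the judge no longer receives the generation thunk in its own context; it sees only the constructed judge prompt. The reformatted `targets_for_input` guard is also worth isolating; a tiny self-contained illustration with made-up data:

```python
# Made-up targets list: two entries for three inputs. The index guard
# falls back to an empty target list instead of raising IndexError.
targets = [["4"], ["Paris"]]
for idx in range(3):
    targets_for_input = targets[idx] if idx < len(targets) else []
    print(idx, targets_for_input)
# 0 ['4']
# 1 ['Paris']
# 2 []
```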

mellea/stdlib/reqlib/md.py

Lines changed: 13 additions & 6 deletions
```diff
@@ -14,11 +14,15 @@ def as_markdown_list(ctx: Context) -> list[str] | None:
     raw_output = ctx.last_output()
     assert raw_output is not None
     try:
-        parsed = mistletoe.Document(raw_output.value)  # type: ignore
-        for child in parsed.children:  # type: ignore
+        assert raw_output.value is not None
+        parsed = mistletoe.Document(raw_output.value)
+        assert parsed.children is not None
+        children = list(parsed.children)
+        for child in children:
             if type(child) is not mistletoe.block_token.List:
                 return None
-            for item in child.children:  # type: ignore
+            assert child.children is not None
+            for item in child.children:
                 xs.append(mistletoe.base_renderer.BaseRenderer().render(item))
         return xs
     except Exception:
@@ -44,10 +48,13 @@ def _md_table(ctx: Context):
     raw_output = ctx.last_output()
     assert raw_output is not None
     try:
-        parsed = mistletoe.Document(raw_output.value)  # type: ignore
-        if len(parsed.children) != 1:  # type: ignore
+        assert raw_output.value is not None
+        parsed = mistletoe.Document(raw_output.value)
+        assert parsed.children is not None
+        children = list(parsed.children)
+        if len(children) != 1:
             return False
-        return type(parsed.children[0]) is mistletoe.block_token.Table  # type: ignore
+        return type(children[0]) is mistletoe.block_token.Table
     except Exception:
         return False
 
```
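The new asserts replace blanket `# type: ignore` comments and make explicit the invariant the code relies on: mistletoe populates `Document.children` after a successful parse. A small standalone check (the sample markdown is illustrative):

```python
import mistletoe
import mistletoe.block_token

doc = mistletoe.Document("- alpha\n- beta\n")
assert doc.children is not None  # same invariant the diff now asserts
first = doc.children[0]
print(type(first) is mistletoe.block_token.List)  # True: top-level list token
print(len(list(first.children)))                  # 2: one token per list item
```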

mellea/stdlib/test_based_eval.py

Lines changed: 12 additions & 6 deletions
```diff
@@ -18,7 +18,7 @@ def __init__(
         inputs: list[str],
         targets: list[list[str]] | None = None,  # can be optional
         test_id: str | None = None,
-        input_ids: list[str] | None = None
+        input_ids: list[str] | None = None,
     ):
         """Initialize TestBasedEval (for a single unit test)."""
         self.source = source
@@ -61,7 +61,7 @@ def from_json_file(cls, filepath: str) -> list["TestBasedEval"]:
         """Load test evaluations from json/jsonl file, return list of TestBasedEval instances, one per 'unit test'."""
         path = Path(filepath)
 
-        with path.open('r') as f:
+        with path.open("r") as f:
             data = json.load(f)
 
         if not isinstance(data, list):
@@ -77,12 +77,18 @@ def from_json_file(cls, filepath: str) -> list["TestBasedEval"]:
 
         for example in examples:
             input_messages = example.get("input", [])
-            user_messages = [msg for msg in input_messages if msg.get("role") == "user"]
+            user_messages = [
+                msg for msg in input_messages if msg.get("role") == "user"
+            ]
             if user_messages:
                 inputs.append(user_messages[-1].get("content", ""))
 
             target_messages = example.get("targets", [])
-            targets_for_input = [msg.get("content", "") for msg in target_messages if msg.get("role") == "assistant"]
+            targets_for_input = [
+                msg.get("content", "")
+                for msg in target_messages
+                if msg.get("role") == "assistant"
+            ]
             targets.append(targets_for_input)
 
             input_ids.append(example.get("input_id", ""))
@@ -94,8 +100,8 @@ def from_json_file(cls, filepath: str) -> list["TestBasedEval"]:
                 inputs=inputs,
                 targets=targets,
                 test_id=test_data.get("id", ""),
-                input_ids=input_ids
+                input_ids=input_ids,
             )
             test_evals.append(test_eval)
 
-        return test_evals
\ No newline at end of file
+        return test_evals
```
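The keys read in `from_json_file` imply a record shape like the one below. This is inferred only from the hunks above; the `examples` key, the exact top-level layout, and the import path are assumptions:

```python
import json
import tempfile

# Hypothetical record shaped to satisfy the reads in from_json_file:
# a top-level list of tests, per-example "input"/"targets" message lists,
# and "input_id"/"id" identifiers. The "examples" key is an assumption.
record = {
    "id": "test-001",
    "examples": [
        {
            "input_id": "ex-0",
            "input": [{"role": "user", "content": "What is 2 + 2?"}],
            "targets": [{"role": "assistant", "content": "4"}],
        }
    ],
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump([record], f)  # from_json_file rejects non-list top levels

# Hypothetical usage, assuming the class's import path in this repo:
# from mellea.stdlib.test_based_eval import TestBasedEval
# test_evals = TestBasedEval.from_json_file(f.name)
```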
