
Commit e7d824d

Add ConsistencyCheckError exception and handle consistency checks in benchmarks.
1 parent: 4d166fc

3 files changed (+45, -4 lines)

src/modelbench/cli.py

Lines changed: 18 additions & 4 deletions
@@ -24,7 +24,11 @@
 import modelgauge.annotators.cheval.registration # noqa: F401
 from modelbench.benchmark_runner import BenchmarkRun, BenchmarkRunner, JsonRunTracker, TqdmRunTracker
 from modelbench.benchmarks import GeneralPurposeAiChatBenchmarkV1, SecurityBenchmark
-from modelbench.consistency_checker import ConsistencyChecker, summarize_consistency_check_results
+from modelbench.consistency_checker import (
+    ConsistencyCheckError,
+    ConsistencyChecker,
+    summarize_consistency_check_results,
+)
 from modelbench.record import dump_json
 from modelbench.standards import Standards
 from modelgauge.config import load_secrets_from_config, write_default_config
@@ -188,7 +192,11 @@ def general_benchmark(
     sut = make_sut(sut_uid)
     benchmark = GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, evaluator)
     check_benchmark(benchmark)
-    run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user)
+    try:
+        run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user)
+    except ConsistencyCheckError as e:
+        echo(termcolor.colored(str(e), "red"), err=True)
+        sys.exit(e.EXIT_CODE)
 
 
 @benchmark.command("security", help="run a security benchmark")
@@ -211,7 +219,11 @@ def security_benchmark(
     sut = make_sut(sut_uid)
     benchmark = SecurityBenchmark(locale, prompt_set, evaluator=evaluator)
     check_benchmark(benchmark)
-    run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user)
+    try:
+        run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, output_dir, run_uid, user)
+    except ConsistencyCheckError as e:
+        echo(termcolor.colored(str(e), "red"), err=True)
+        sys.exit(e.EXIT_CODE)
 
 
 def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, outputdir, run_uid, user):
@@ -233,7 +245,9 @@ def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, run_path, outputdir, run_uid, user):
     annotation_records.write(json.dumps(annotations))
     print(f"Wrote annotations for {benchmark.uid} to {annotation_path}.")
 
-    run_consistency_check(run.journal_path, verbose=True)
+    consistent = run_consistency_check(run.journal_path, verbose=True)
+    if not consistent:
+        raise ConsistencyCheckError("Consistency check failed for the benchmark run.")
 
 
 @cli.command(
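
Taken together, the cli.py changes make a failed journal consistency check fatal for the run: run_and_report_benchmark now raises, and each benchmark command catches the exception, prints it in red on stderr, and exits with the error's dedicated exit code. Below is a condensed, self-contained sketch of that control flow; the stubbed run_consistency_check and the plain print are stand-ins, not the real modelbench code.

# Condensed sketch of the failure path introduced in this commit.
import sys


class ConsistencyCheckError(Exception):
    EXIT_CODE = 2  # defined in src/modelbench/consistency_checker.py in this commit


def run_consistency_check(journal_path, verbose=False):
    # Stand-in: the real checker inspects the run journal and returns a bool.
    return False


def run_and_report_benchmark(journal_path):
    # ...benchmark execution, record and annotation writing elided...
    consistent = run_consistency_check(journal_path, verbose=True)
    if not consistent:
        raise ConsistencyCheckError("Consistency check failed for the benchmark run.")


def benchmark_command(journal_path):
    # Mirrors the try/except added to the general and security CLI commands.
    try:
        run_and_report_benchmark(journal_path)
    except ConsistencyCheckError as e:
        print(str(e), file=sys.stderr)  # the CLI uses echo(termcolor.colored(..., "red"), err=True)
        sys.exit(e.EXIT_CODE)


if __name__ == "__main__":
    benchmark_command("journal.jsonl")  # exits with status 2 because the stub reports a failure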

src/modelbench/consistency_checker.py

Lines changed: 4 additions & 0 deletions
@@ -620,3 +620,7 @@ def summarize_consistency_check_results(checkers: List[ConsistencyChecker]):
 
     console = Console()
     console.print(table)
+
+
+class ConsistencyCheckError(Exception):
+    EXIT_CODE = 2
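
The dedicated EXIT_CODE lets wrappers and CI scripts distinguish a failed consistency check from other failures. A minimal sketch follows, assuming the package's console entry point is invoked as `modelbench`; the command name and arguments are assumptions for illustration, not confirmed by this diff.

# Hypothetical wrapper: run the benchmark CLI and react to the new exit code.
import subprocess
import sys

CONSISTENCY_CHECK_EXIT_CODE = 2  # ConsistencyCheckError.EXIT_CODE in this commit

result = subprocess.run(["modelbench", "benchmark", "general", "--sut", "fake-sut"])
if result.returncode == CONSISTENCY_CHECK_EXIT_CODE:
    print("Benchmark ran, but the journal failed its consistency check.", file=sys.stderr)
sys.exit(result.returncode)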

tests/modelbench_tests/test_run.py

Lines changed: 23 additions & 0 deletions
@@ -21,6 +21,7 @@
     SecurityScore,
 )
 from modelbench.cli import cli
+from modelbench.consistency_checker import ConsistencyCheckError
 from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards
 from modelbench.scoring import ValueEstimate
 from modelbench.standards import NoStandardsFileError, OverwriteStandardsFileError
@@ -318,6 +319,7 @@ def invoke(command, args=None, **kwargs):
     @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
     def test_benchmark_basic_run_produces_json(
         self,
+        monkeypatch,
         runner,
         mock_run_benchmarks,
         mock_score_benchmarks,
@@ -346,6 +348,7 @@ def test_benchmark_basic_run_produces_json(
             sut_uid,
             *benchmark_options,
         ]
+        monkeypatch.setattr(modelbench.cli, "run_consistency_check", lambda *args, **kwargs: True)
         result = runner(
             cli,
             command_options,
@@ -411,6 +414,7 @@ def test_benchmark_multiple_suts_produces_json(
 
         mock = MagicMock(return_value=[self.mock_score(sut_uid, benchmark), self.mock_score("demo_yes_no", benchmark)])
         monkeypatch.setattr(modelbench.cli, "score_benchmarks", mock)
+        monkeypatch.setattr(modelbench.cli, "run_consistency_check", lambda *args, **kwargs: True)
 
         result = runner(
             cli,
@@ -430,6 +434,25 @@ def test_benchmark_multiple_suts_produces_json(
         assert result.exit_code == 0
         assert (run_dir / "records" / f"benchmark_record-{benchmark.uid}.json").exists
 
+    @pytest.mark.parametrize("benchmark_type", ["general", "security"])
+    def test_general_benchmark_exits_when_consistency_fails(self, runner, benchmark_type, sut, monkeypatch):
+        fail_run = MagicMock(side_effect=ConsistencyCheckError("consistency failed"))
+        monkeypatch.setattr(modelbench.cli, "run_and_report_benchmark", fail_run)
+
+        result = runner(
+            cli,
+            [
+                "benchmark",
+                benchmark_type,
+                "--sut",
+                sut.uid,
+            ],
+            catch_exceptions=False,
+        )
+
+        fail_run.assert_called_once()
+        assert result.exit_code == ConsistencyCheckError.EXIT_CODE
+
     def test_benchmark_bad_sut_errors_out(self, runner):
         benchmark_options = ["--version", "1.1"]
         benchmark_options.extend(["--locale", "en_us"])
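
The test changes cover both sides of the new behavior: the existing happy-path tests monkeypatch run_consistency_check to return True so they still exit 0, and the new parametrized test forces run_and_report_benchmark to raise and asserts the CLI exits with ConsistencyCheckError.EXIT_CODE. Below is a self-contained illustration of that exit-code test pattern using a toy click command; it is not modelbench's actual CLI or fixtures.

# Toy reproduction of the exit-code test pattern, runnable with pytest.
import sys
from unittest.mock import MagicMock

import click
from click.testing import CliRunner


class ConsistencyCheckError(Exception):
    EXIT_CODE = 2


def run_and_report_benchmark():
    pass  # replaced with a failing mock inside the test


@click.command()
def bench():
    try:
        run_and_report_benchmark()
    except ConsistencyCheckError as e:
        click.echo(str(e), err=True)
        sys.exit(e.EXIT_CODE)


def test_exits_when_consistency_fails(monkeypatch):
    fail_run = MagicMock(side_effect=ConsistencyCheckError("consistency failed"))
    monkeypatch.setattr(sys.modules[__name__], "run_and_report_benchmark", fail_run)

    # CliRunner converts sys.exit into result.exit_code, even with catch_exceptions=False.
    result = CliRunner().invoke(bench, [], catch_exceptions=False)

    fail_run.assert_called_once()
    assert result.exit_code == ConsistencyCheckError.EXIT_CODE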
