Commit 2ef1bc5

Improve test (#674)
The idea is to read the expected results first, so that if they were not fetched with Git LFS, the user receives an error before the model is run on the data.
1 parent 4ab19e8 commit 2ef1bc5

File tree: 1 file changed (+4, -2 lines)
tests/slow_tests/test_accelerate_model.py

Lines changed: 4 additions & 2 deletions
@@ -86,18 +86,20 @@ def test_accelerate_model_prediction(tests: list[ModelInput]):
     """Evaluates a model on a full task - is parametrized using pytest_generate_test"""
     model_args, get_predictions = tests

-    predictions = get_predictions()["results"]
-
     # Load the reference results
     with open(model_args["results_file"], "r") as f:
         reference_results = json.load(f)["results"]

     # Change the key names, replace '|' with ':'
     reference_results = {k.replace("|", ":"): v for k, v in reference_results.items()}

+    # Get the predictions
+    predictions = get_predictions()["results"]
+
     # Convert defaultdict values to regular dict for comparison
     predictions_dict = {k: dict(v) if hasattr(v, "default_factory") else v for k, v in predictions.items()}

+    # Compare the predictions with the reference results
     diff = DeepDiff(reference_results, predictions_dict, ignore_numeric_type_changes=True, math_epsilon=0.05)

     assert diff == {}, f"Differences found: {diff}"
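
For context, here is a minimal sketch (not part of the commit) of the failure this reordering surfaces early. When a Git LFS object has not been fetched, the file on disk holds a short text pointer instead of the real JSON payload, so json.load() raises a JSONDecodeError as soon as the reference results are read, before any model prediction runs. The file name and pointer contents below are hypothetical.

import json

# Hypothetical contents of an unfetched Git LFS file: a small text pointer,
# not the JSON results the test expects.
lfs_pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:0000000000000000000000000000000000000000000000000000000000000000\n"
    "size 12345\n"
)

with open("reference_results.json", "w") as f:  # hypothetical file name
    f.write(lfs_pointer)

try:
    # Mirrors the test's first step after the change: read the reference results
    # before calling get_predictions(), so a missing LFS fetch fails here.
    with open("reference_results.json", "r") as f:
        reference_results = json.load(f)["results"]
except json.JSONDecodeError as exc:
    print(f"Reference results look like an unfetched Git LFS pointer: {exc}")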
