Skip to content

Commit 0e96cd1

Browse files
Fixed #3005: Processes both formats of model_args: string and dictionary (#3097)
* git push --force correctly processes both formats of model_args: string and dictionary * extract to function for better testing * nit --------- Co-authored-by: Baber <[email protected]>
1 parent 6e91fdc commit 0e96cd1

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

scripts/zeno_visualize.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import re
66
from pathlib import Path
7+
from typing import Union
78

89
import pandas as pd
910
from zeno_client import ZenoClient, ZenoMetric
@@ -35,6 +36,22 @@ def parse_args():
3536
return parser.parse_args()
3637

3738

39+
def sanitize_string(model_args_raw: Union[str, dict]) -> str:
    """Return the model_args value as a string with special characters replaced.

    Args:
        model_args_raw: The ``model_args`` value. A string is used as-is and a
            dictionary is serialized with ``json.dumps``. Any other value
            (for example ``None``) is converted with ``str`` instead of being
            passed to ``re.sub``, which would raise ``TypeError``.

    Returns:
        The string form in which every run of double quotes, angle brackets,
        colons, slashes, pipes, backslashes, question marks, asterisks or
        square brackets is collapsed into a single ``__``.
    """
    if isinstance(model_args_raw, str):
        model_args_str = model_args_raw
    elif isinstance(model_args_raw, dict):
        # Dictionaries are serialized to their JSON text before sanitizing.
        model_args_str = json.dumps(model_args_raw)
    else:
        # re.sub only accepts strings here; fall back to the plain str() form.
        model_args_str = str(model_args_raw)

    # A whole run of special characters becomes one "__", not one per character.
    return re.sub(r"[\"<>:/|\\?*\[\]]+", "__", model_args_str)
53+
54+
3855
def main():
3956
"""Upload the results of your benchmark tasks to the Zeno AI evaluation platform.
4057
@@ -87,13 +104,16 @@ def main():
87104
latest_sample_results = get_latest_filename(
88105
[Path(f).name for f in model_sample_filenames if task in f]
89106
)
90-
model_args = re.sub(
91-
r"[\"<>:/\|\\?\*\[\]]+",
92-
"__",
107+
# Load the model_args, which can be either a string or a dictionary
108+
model_args = sanitize_string(
93109
json.load(
94-
open(Path(args.data_path, model, latest_results), encoding="utf-8")
95-
)["config"]["model_args"],
110+
open(
111+
Path(args.data_path, model, latest_results),
112+
encoding="utf-8",
113+
)
114+
)["config"]["model_args"]
96115
)
116+
97117
print(model_args)
98118
data = []
99119
with open(
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import json
2+
import re
3+
4+
import pytest
5+
6+
from scripts.zeno_visualize import sanitize_string
7+
8+
9+
# pytest.skip() is an imperative call, not a decorator; the skip marker is the
# decorator form and leaves the function directly callable.
@pytest.mark.skip(reason="requires zeno_client dependency")
def test_zeno_sanitize_string():
    """
    Test that sanitize_string from zeno_visualize.py properly handles
    different model_args formats (string and dictionary).
    """
    # Same character class the script used before the helper was extracted.
    pattern = r"[\"<>:/\|\\?\*\[\]]+"

    # Test case 1: model_args as a string
    string_model_args = "pretrained=EleutherAI/pythia-160m,dtype=float32"
    result_string = sanitize_string(string_model_args)
    expected_string = re.sub(pattern, "__", string_model_args)

    # Test case 2: model_args as a dictionary
    dict_model_args = {"pretrained": "EleutherAI/pythia-160m", "dtype": "float32"}
    result_dict = sanitize_string(dict_model_args)
    expected_dict = re.sub(pattern, "__", json.dumps(dict_model_args))

    # Verify the results
    assert result_string == expected_string
    assert result_dict == expected_dict

    # Also test that the sanitization works as expected
    assert ":" not in result_string  # No colons in sanitized output
    assert "/" not in result_string  # No slashes in sanitized output
    assert ":" not in result_dict  # No colons in sanitized output
    assert "/" not in result_dict  # No slashes in sanitized output
    assert "<" not in result_dict  # No angle brackets in sanitized output
36+
37+
38+
if __name__ == "__main__":
    # Direct invocation: run the check without going through the pytest runner.
    test_zeno_sanitize_string()
    print("All tests passed.")

0 commit comments

Comments
 (0)