Skip to content

Commit 58f6b7a

Browse files
authored
Enforce SUT uid params (#1495)
* Document that UIDs should be quoted * different types of sut data * Pass config kwargs to factories
1 parent 080c865 commit 58f6b7a

File tree

7 files changed

+100
-35
lines changed

7 files changed

+100
-35
lines changed

src/modelgauge/cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,12 @@ def list_secrets() -> None:
132132

133133
@cli.command()
134134
@LOCAL_PLUGIN_DIR_OPTION
135-
@click.option("--sut", "-s", help="Which SUT to run.", required=True)
135+
@click.option(
136+
"--sut",
137+
"-s",
138+
help="Which SUT to run. Please quote the value to ensure that dynamic parameterizations are included.",
139+
required=True,
140+
)
136141
@sut_options_options
137142
@click.option("--prompt", help="The full text to send to the SUT.", required=True)
138143
def run_sut(

src/modelgauge/dynamic_sut_factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,5 @@ def get_secrets(self) -> list[InjectSecret]:
4141

4242
@abstractmethod
4343
def make_sut(self, sut_definition: SUTDefinition) -> SUT:
44+
"""Factories that handle special SUT config parameters (e.g. moderated, reasoning) must accept them as kwargs."""
4445
pass

src/modelgauge/sut_definition.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,19 @@ def name_for_label(self, label):
2323
raise (ValueError(f"for static elements, {label} must match {self.label}"))
2424

2525

26-
class PrefixSUTSpecificationElement(SUTSpecificationElement):
26+
class SUTMetadataElement(SUTSpecificationElement):
27+
"""Core SUT information."""
28+
29+
pass
30+
31+
32+
class SUTConfigElement(SUTSpecificationElement):
33+
"""Optional configuration data passed to a SUT on initialization."""
34+
35+
required = False
36+
37+
38+
class PrefixSUTSpecificationElement(SUTConfigElement):
2739

2840
def matches(self, field_name):
2941
return field_name.startswith(self.name)
@@ -37,15 +49,14 @@ class SUTSpecification:
3749

3850
def __init__(self):
3951
fields = [
40-
SUTSpecificationElement("model", "m", str, True),
41-
SUTSpecificationElement("driver", "d", str, True),
42-
SUTSpecificationElement("maker", "mk", str),
43-
SUTSpecificationElement("provider", "pr", str),
44-
SUTSpecificationElement("display_name", "dn", str),
45-
SUTSpecificationElement("reasoning", "reas", bool),
46-
SUTSpecificationElement("moderated", "mod", bool),
47-
SUTSpecificationElement("date", "dt", str),
48-
SUTSpecificationElement("base_url", "url", str),
52+
SUTMetadataElement("model", "m", str, True),
53+
SUTMetadataElement("driver", "d", str, True),
54+
SUTMetadataElement("maker", "mk", str),
55+
SUTMetadataElement("provider", "pr", str),
56+
SUTMetadataElement("date", "dt", str),
57+
SUTConfigElement("reasoning", "reas", bool),
58+
SUTConfigElement("moderated", "mod", bool),
59+
SUTConfigElement("base_url", "url", str),
4960
]
5061

5162
self._wildcard_fields = [PrefixSUTSpecificationElement("vllm-", "vllm", str)]
@@ -59,9 +70,12 @@ def knows(self, name: str):
5970
def requires(self, name: str):
6071
return self.knows(name) and self._fields_by_name[name].required
6172

62-
def validate(self, data: dict) -> bool:
73+
def validate(self, metadata: dict, config_data: dict) -> bool:
6374
for field_spec in self._fields_by_name.values():
64-
value = data.get(field_spec.name, None)
75+
if isinstance(field_spec, SUTMetadataElement):
76+
value = metadata.get(field_spec.name, None)
77+
else:
78+
value = config_data.get(field_spec.name, None)
6579
if field_spec.required and value is None:
6680
raise ValueError(f"Field {field_spec.name} is required.")
6781
if value is not None and not isinstance(value, field_spec.value_type):
@@ -79,26 +93,34 @@ def element_for_label(self, label: str):
7993
return element
8094
return None
8195

96+
def element_for_name(self, name: str):
97+
if name in self._fields_by_name:
98+
return self._fields_by_name[name]
99+
for element in self._wildcard_fields:
100+
if element.matches(name):
101+
return element
102+
return None
103+
82104

83105
DEFINITION_VALUE_TYPES = Union[str, int, float, bool, None]
84106

85107

86108
class SUTDefinition:
87109
"""The data in a SUT configuration file or JSON blob"""
88110

89-
_data: dict[str, DEFINITION_VALUE_TYPES]
111+
_metadata: dict[str, DEFINITION_VALUE_TYPES]
90112

91113
def __init__(self, data=None, **kwargs):
92114
self.spec = SUTSpecification()
93-
self._data = {}
115+
self._metadata = {} # Core SUT information.
116+
self.config_data = {} # Everything that comes after ";"
94117

95118
if data:
96119
for k, v in data.items():
97120
self._add(k, v)
98121
for k, v in kwargs.items():
99122
self._add(k, v)
100-
if not self.spec.validate(self._data):
101-
raise ValueError(f"Invalid data: {self._data}")
123+
self.spec.validate(self._metadata, self.config_data)
102124

103125
generator = SUTUIDGenerator(self)
104126
self.uid = generator.uid
@@ -110,31 +132,37 @@ def __str__(self):
110132
def _add(self, key: str, value: DEFINITION_VALUE_TYPES):
111133
if isinstance(value, str):
112134
value = value.strip()
113-
if self.spec.knows(key):
114-
self._data[key] = value
115-
else:
135+
if not self.spec.knows(key):
116136
raise ValueError(f"Don't know what to do with {key}")
137+
spec_element = self.spec.element_for_name(key)
138+
if isinstance(spec_element, SUTMetadataElement):
139+
self._metadata[key] = value
140+
elif isinstance(spec_element, SUTConfigElement):
141+
self.config_data[key] = value
142+
else:
143+
raise ValueError(f"Unknown spec element type {spec_element} for {key}")
117144

118145
def get(self, field: str, default=None) -> DEFINITION_VALUE_TYPES:
119-
return self._data.get(field, default)
146+
return self._metadata.get(field, self.config_data.get(field, default))
120147

121148
def get_matching(self, label: str) -> Mapping[str, DEFINITION_VALUE_TYPES] | None:
122149
element = self.spec.element_for_label(label)
123150
if not element:
124151
return None
125152
result = {}
126-
for k, v in self._data.items():
127-
if element.matches(k):
128-
result[k] = v
153+
for items in (self._metadata.items(), self.config_data.items()):
154+
for k, v in items:
155+
if element.matches(k):
156+
result[k] = v
129157
return result
130158

131159
def to_dynamic_sut_metadata(self) -> DynamicSUTMetadata:
132160
return DynamicSUTMetadata(
133-
model=self._data["model"], # type: ignore
134-
driver=self._data["driver"], # type: ignore
135-
maker=self._data.get("maker", None), # type: ignore
136-
provider=self._data.get("provider", None), # type: ignore
137-
date=self._data.get("date", None), # type: ignore
161+
model=self._metadata["model"], # type: ignore
162+
driver=self._metadata["driver"], # type: ignore
163+
maker=self._metadata.get("maker", None), # type: ignore
164+
provider=self._metadata.get("provider", None), # type: ignore
165+
date=self._metadata.get("date", None), # type: ignore
138166
)
139167

140168
def external_model_name(self) -> str:
@@ -231,7 +259,6 @@ class SUTUIDGenerator:
231259
order = (
232260
"moderated",
233261
"reasoning",
234-
"display_name",
235262
"base_url", # for OpenAI-compatible SUTs
236263
)
237264
field_separator = RICH_UID_FIELD_SEPARATOR

src/modelgauge/sut_factory.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class SUTNotFoundException(Exception):
2121
pass
2222

2323

24+
class IncompatibleSUTParamsError(Exception):
25+
pass
26+
27+
2428
class SUTType(Enum):
2529
DYNAMIC = "dynamic"
2630
KNOWN = "known"
@@ -185,7 +189,13 @@ def _make_dynamic_sut(self, uid: str) -> SUT:
185189
factory = self.dynamic_sut_factories.get(sut_definition.get("driver")) # type: ignore
186190
if not factory:
187191
raise UnknownSUTMakerError(f'Don\'t know how to make dynamic sut "{uid}"')
188-
return factory.make_sut(sut_definition)
192+
try:
193+
sut = factory.make_sut(sut_definition, **sut_definition.config_data)
194+
except TypeError:
195+
raise IncompatibleSUTParamsError(
196+
f"The {factory.__class__.__name__} factory cannot handle some dynamic SUT parameters specified in the uid: {sut_definition.config_data}."
197+
)
198+
return sut
189199

190200
def keys(self) -> list[str]:
191201
"""Mimic the registry interface."""

tests/modelgauge_tests/test_dynamic_sut_factory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ def make_sut(self, sut_definition: SUTDefinition):
1515
return FakeSUT(sut_definition.dynamic_uid)
1616

1717

18+
class FakeDynamicFactoryHandlesMod(FakeDynamicFactory):
19+
def make_sut(self, sut_definition: SUTDefinition, moderated: bool = False):
20+
return FakeSUT(sut_definition.dynamic_uid)
21+
22+
1823
def test_injected_secrets():
1924
factory = FakeDynamicFactory(
2025
{"some-scope": {"some-key": "some-value"}, "optional-scope": {"optional-key": "optional-value"}}

tests/modelgauge_tests/test_sut_definition.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@ def test_to_dynamic_sut_metadata():
3232

3333

3434
def test_parse_rich_sut_uid():
35-
uid = "google/gemma-3-27b-it:nebius:hfrelay;url=https://example.com/"
35+
uid = "google/gemma-3-27b-it:nebius:hfrelay;reas=y;url=https://example.com/"
3636
definition = SUTDefinition.parse(uid)
3737
assert definition.get("model") == "gemma-3-27b-it"
3838
assert definition.get("maker") == "google"
3939
assert definition.get("driver") == "hfrelay"
4040
assert definition.get("provider") == "nebius"
4141
assert definition.get("base_url") == "https://example.com/"
42+
assert definition.get("reasoning") is True
43+
44+
assert definition.uid == uid
4245

4346

4447
def test_vllm_parameters():

tests/modelgauge_tests/test_sut_factory.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from modelgauge.dynamic_sut_factory import UnknownSUTMakerError
55
from modelgauge.instance_factory import InstanceFactory
66
from modelgauge.sut import SUT
7-
from modelgauge.sut_factory import SUTFactory, SUTNotFoundException, SUTType
7+
from modelgauge.sut_factory import IncompatibleSUTParamsError, SUTFactory, SUTNotFoundException, SUTType
88
from modelgauge_tests.fake_sut import FakeSUT
9-
from modelgauge_tests.test_dynamic_sut_factory import FakeDynamicFactory
9+
from modelgauge_tests.test_dynamic_sut_factory import FakeDynamicFactory, FakeDynamicFactoryHandlesMod
1010

1111
KNOWN_UID = "known"
1212
UNKNOWN_UID = "pleasedontregisterasutwiththisuid"
@@ -26,7 +26,11 @@ def sut_factory():
2626
def sut_factory_dynamic():
2727
"""SUT factory that patches the dynamic SUT factories."""
2828
registry = InstanceFactory[SUT]()
29-
dynamic_factories = {"driver1": FakeDynamicFactory({}), "driver2": FakeDynamicFactory({})}
29+
dynamic_factories = {
30+
"driver1": FakeDynamicFactory({}),
31+
"driver2": FakeDynamicFactory({}),
32+
"mod_driver": FakeDynamicFactoryHandlesMod({}),
33+
}
3034
with patch(
3135
"modelgauge.sut_factory.SUTFactory._load_dynamic_sut_factories",
3236
return_value=dynamic_factories,
@@ -67,6 +71,16 @@ def test_make_instance_dynamic_unknown_driver(sut_factory_dynamic):
6771
sut_factory_dynamic.make_instance("google/gemma:unknown", secrets={})
6872

6973

74+
def test_make_instance_dynamic_with_params(sut_factory_dynamic):
75+
sut = sut_factory_dynamic.make_instance("google/gemma:mod_driver;mod=y", secrets={})
76+
assert isinstance(sut, FakeSUT)
77+
78+
79+
def test_make_instance_dynamic_incompatible_params(sut_factory_dynamic):
80+
with pytest.raises(IncompatibleSUTParamsError):
81+
sut_factory_dynamic.make_instance("google/gemma:driver1;mod=y", secrets={})
82+
83+
7084
def test_make_instance_unknown_type(sut_factory):
7185
with pytest.raises(SUTNotFoundException):
7286
sut_factory.make_instance(UNKNOWN_UID, secrets={})

0 commit comments

Comments
 (0)