Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/modelgauge/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,12 @@ def list_secrets() -> None:

@cli.command()
@LOCAL_PLUGIN_DIR_OPTION
@click.option("--sut", "-s", help="Which SUT to run.", required=True)
@click.option(
"--sut",
"-s",
help="Which SUT to run. Please quote the value to ensure that dynamic parameterizations are included.",
required=True,
)
@sut_options_options
@click.option("--prompt", help="The full text to send to the SUT.", required=True)
def run_sut(
Expand Down
1 change: 1 addition & 0 deletions src/modelgauge/dynamic_sut_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ def get_secrets(self) -> list[InjectSecret]:

@abstractmethod
def make_sut(self, sut_definition: SUTDefinition) -> SUT:
"""Factories that handle special SUT config parameters (e.g. moderated, reasoning) must accept them as kwargs."""
pass
85 changes: 56 additions & 29 deletions src/modelgauge/sut_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,19 @@ def name_for_label(self, label):
raise (ValueError(f"for static elements, {label} must match {self.label}"))


class PrefixSUTSpecificationElement(SUTSpecificationElement):
class SUTMetadataElement(SUTSpecificationElement):
"""Core SUT information."""

pass


class SUTConfigElement(SUTSpecificationElement):
"""Optional configuration data passed to a SUT on initialization."""

required = False


class PrefixSUTSpecificationElement(SUTConfigElement):

def matches(self, field_name):
return field_name.startswith(self.name)
Expand All @@ -37,15 +49,14 @@ class SUTSpecification:

def __init__(self):
fields = [
SUTSpecificationElement("model", "m", str, True),
SUTSpecificationElement("driver", "d", str, True),
SUTSpecificationElement("maker", "mk", str),
SUTSpecificationElement("provider", "pr", str),
SUTSpecificationElement("display_name", "dn", str),
SUTSpecificationElement("reasoning", "reas", bool),
SUTSpecificationElement("moderated", "mod", bool),
SUTSpecificationElement("date", "dt", str),
SUTSpecificationElement("base_url", "url", str),
SUTMetadataElement("model", "m", str, True),
SUTMetadataElement("driver", "d", str, True),
SUTMetadataElement("maker", "mk", str),
SUTMetadataElement("provider", "pr", str),
SUTMetadataElement("date", "dt", str),
SUTConfigElement("reasoning", "reas", bool),
SUTConfigElement("moderated", "mod", bool),
SUTConfigElement("base_url", "url", str),
]

self._wildcard_fields = [PrefixSUTSpecificationElement("vllm-", "vllm", str)]
Expand All @@ -59,9 +70,12 @@ def knows(self, name: str):
def requires(self, name: str):
return self.knows(name) and self._fields_by_name[name].required

def validate(self, data: dict) -> bool:
def validate(self, metadata: dict, config_data: dict) -> bool:
for field_spec in self._fields_by_name.values():
value = data.get(field_spec.name, None)
if isinstance(field_spec, SUTMetadataElement):
value = metadata.get(field_spec.name, None)
else:
value = config_data.get(field_spec.name, None)
if field_spec.required and value is None:
raise ValueError(f"Field {field_spec.name} is required.")
if value is not None and not isinstance(value, field_spec.value_type):
Expand All @@ -79,26 +93,34 @@ def element_for_label(self, label: str):
return element
return None

def element_for_name(self, name: str):
if name in self._fields_by_name:
return self._fields_by_name[name]
for element in self._wildcard_fields:
if element.matches(name):
return element
return None


DEFINITION_VALUE_TYPES = Union[str, int, float, bool, None]


class SUTDefinition:
"""The data in a SUT configuration file or JSON blob"""

_data: dict[str, DEFINITION_VALUE_TYPES]
_metadata: dict[str, DEFINITION_VALUE_TYPES]

def __init__(self, data=None, **kwargs):
self.spec = SUTSpecification()
self._data = {}
self._metadata = {} # Core SUT information.
self.config_data = {} # Everything that comes after ";"

if data:
for k, v in data.items():
self._add(k, v)
for k, v in kwargs.items():
self._add(k, v)
if not self.spec.validate(self._data):
raise ValueError(f"Invalid data: {self._data}")
self.spec.validate(self._metadata, self.config_data)

generator = SUTUIDGenerator(self)
self.uid = generator.uid
Expand All @@ -110,31 +132,37 @@ def __str__(self):
def _add(self, key: str, value: DEFINITION_VALUE_TYPES):
if isinstance(value, str):
value = value.strip()
if self.spec.knows(key):
self._data[key] = value
else:
if not self.spec.knows(key):
raise ValueError(f"Don't know what to do with {key}")
spec_element = self.spec.element_for_name(key)
if isinstance(spec_element, SUTMetadataElement):
self._metadata[key] = value
elif isinstance(spec_element, SUTConfigElement):
self.config_data[key] = value
else:
raise ValueError(f"Unknown spec element type {spec_element} for {key}")

def get(self, field: str, default=None) -> DEFINITION_VALUE_TYPES:
return self._data.get(field, default)
return self._metadata.get(field, self.config_data.get(field, default))

def get_matching(self, label: str) -> Mapping[str, DEFINITION_VALUE_TYPES] | None:
element = self.spec.element_for_label(label)
if not element:
return None
result = {}
for k, v in self._data.items():
if element.matches(k):
result[k] = v
for items in (self._metadata.items(), self.config_data.items()):
for k, v in items:
if element.matches(k):
result[k] = v
return result

def to_dynamic_sut_metadata(self) -> DynamicSUTMetadata:
return DynamicSUTMetadata(
model=self._data["model"], # type: ignore
driver=self._data["driver"], # type: ignore
maker=self._data.get("maker", None), # type: ignore
provider=self._data.get("provider", None), # type: ignore
date=self._data.get("date", None), # type: ignore
model=self._metadata["model"], # type: ignore
driver=self._metadata["driver"], # type: ignore
maker=self._metadata.get("maker", None), # type: ignore
provider=self._metadata.get("provider", None), # type: ignore
date=self._metadata.get("date", None), # type: ignore
)

def external_model_name(self) -> str:
Expand Down Expand Up @@ -231,7 +259,6 @@ class SUTUIDGenerator:
order = (
"moderated",
"reasoning",
"display_name",
"base_url", # for OpenAI-compatible SUTs
)
field_separator = RICH_UID_FIELD_SEPARATOR
Expand Down
12 changes: 11 additions & 1 deletion src/modelgauge/sut_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class SUTNotFoundException(Exception):
pass


class IncompatibleSUTParamsError(Exception):
pass


class SUTType(Enum):
DYNAMIC = "dynamic"
KNOWN = "known"
Expand Down Expand Up @@ -183,7 +187,13 @@ def _make_dynamic_sut(self, uid: str) -> SUT:
factory = self.dynamic_sut_factories.get(sut_definition.get("driver")) # type: ignore
if not factory:
raise UnknownSUTMakerError(f'Don\'t know how to make dynamic sut "{uid}"')
return factory.make_sut(sut_definition)
try:
sut = factory.make_sut(sut_definition, **sut_definition.config_data)
except TypeError:
raise IncompatibleSUTParamsError(
f"The {factory.__class__.__name__} factory cannot handle some dynamic SUT parameters specified in the uid: {sut_definition.config_data}."
)
return sut

def keys(self) -> list[str]:
"""Mimic the registry interface."""
Expand Down
5 changes: 5 additions & 0 deletions tests/modelgauge_tests/test_dynamic_sut_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def make_sut(self, sut_definition: SUTDefinition):
return FakeSUT(sut_definition.dynamic_uid)


class FakeDynamicFactoryHandlesMod(FakeDynamicFactory):
def make_sut(self, sut_definition: SUTDefinition, moderated: bool = False):
return FakeSUT(sut_definition.dynamic_uid)


def test_injected_secrets():
factory = FakeDynamicFactory(
{"some-scope": {"some-key": "some-value"}, "optional-scope": {"optional-key": "optional-value"}}
Expand Down
5 changes: 4 additions & 1 deletion tests/modelgauge_tests/test_sut_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ def test_to_dynamic_sut_metadata():


def test_parse_rich_sut_uid():
uid = "google/gemma-3-27b-it:nebius:hfrelay;url=https://example.com/"
uid = "google/gemma-3-27b-it:nebius:hfrelay;reas=y;url=https://example.com/"
definition = SUTDefinition.parse(uid)
assert definition.get("model") == "gemma-3-27b-it"
assert definition.get("maker") == "google"
assert definition.get("driver") == "hfrelay"
assert definition.get("provider") == "nebius"
assert definition.get("base_url") == "https://example.com/"
assert definition.get("reasoning") is True

assert definition.uid == uid


def test_vllm_parameters():
Expand Down
20 changes: 17 additions & 3 deletions tests/modelgauge_tests/test_sut_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from modelgauge.dynamic_sut_factory import UnknownSUTMakerError
from modelgauge.instance_factory import InstanceFactory
from modelgauge.sut import SUT
from modelgauge.sut_factory import SUTFactory, SUTNotFoundException, SUTType
from modelgauge.sut_factory import IncompatibleSUTParamsError, SUTFactory, SUTNotFoundException, SUTType
from modelgauge_tests.fake_sut import FakeSUT
from modelgauge_tests.test_dynamic_sut_factory import FakeDynamicFactory
from modelgauge_tests.test_dynamic_sut_factory import FakeDynamicFactory, FakeDynamicFactoryHandlesMod

KNOWN_UID = "known"
UNKNOWN_UID = "pleasedontregisterasutwiththisuid"
Expand All @@ -26,7 +26,11 @@ def sut_factory():
def sut_factory_dynamic():
"""SUT factory that patches the dynamic SUT factories."""
registry = InstanceFactory[SUT]()
dynamic_factories = {"driver1": FakeDynamicFactory({}), "driver2": FakeDynamicFactory({})}
dynamic_factories = {
"driver1": FakeDynamicFactory({}),
"driver2": FakeDynamicFactory({}),
"mod_driver": FakeDynamicFactoryHandlesMod({}),
}
with patch(
"modelgauge.sut_factory.SUTFactory._load_dynamic_sut_factories",
return_value=dynamic_factories,
Expand Down Expand Up @@ -67,6 +71,16 @@ def test_make_instance_dynamic_unknown_driver(sut_factory_dynamic):
sut_factory_dynamic.make_instance("google/gemma:unknown", secrets={})


def test_make_instance_dynamic_with_params(sut_factory_dynamic):
sut = sut_factory_dynamic.make_instance("google/gemma:mod_driver;mod=y", secrets={})
assert isinstance(sut, FakeSUT)


def test_make_instance_dynamic_incompatible_params(sut_factory_dynamic):
with pytest.raises(IncompatibleSUTParamsError):
sut_factory_dynamic.make_instance("google/gemma:driver1;mod=y", secrets={})


def test_make_instance_unknown_type(sut_factory):
with pytest.raises(SUTNotFoundException):
sut_factory.make_instance(UNKNOWN_UID, secrets={})
Expand Down
Loading