Skip to content

Commit 2ca5067

Browse files
authored
make InferenceParamters optional in model config (#46)
1 parent d0439fe commit 2ca5067

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

src/data_designer/config/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) -
208208
class ModelConfig(ConfigBase):
209209
alias: str
210210
model: str
211-
inference_parameters: InferenceParameters
211+
inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
212212
provider: Optional[str] = None
213213

214214

tests/config/test_models.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ def test_image_context_validate_image_format():
4545
ImageContext(column_name="image_base64", data_type=ModalityDataType.BASE64)
4646

4747

48+
def test_inference_parameters_default_construction():
49+
empty_inference_parameters = InferenceParameters()
50+
assert empty_inference_parameters.generate_kwargs == {}
51+
assert empty_inference_parameters.max_parallel_requests == 4
52+
53+
4854
def test_inference_parameters_generate_kwargs():
4955
assert InferenceParameters(
5056
temperature=0.95,
@@ -203,8 +209,8 @@ def test_generation_parameters_max_tokens_validation():
203209

204210
def test_load_model_configs():
205211
stub_model_configs = [
206-
ModelConfig(alias="test", model="test", inference_parameters=InferenceParameters()),
207-
ModelConfig(alias="test2", model="test2", inference_parameters=InferenceParameters()),
212+
ModelConfig(alias="test", model="test"),
213+
ModelConfig(alias="test2", model="test2"),
208214
]
209215
stub_model_configs_dict_list = [mc.model_dump() for mc in stub_model_configs]
210216
assert load_model_configs([]) == []
@@ -240,3 +246,8 @@ def test_load_model_configs():
240246
tmp_file.write(json.dumps(invalid_model_configs).encode("utf-8"))
241247
tmp_file.flush()
242248
load_model_configs(tmp_file.name)
249+
250+
251+
def test_model_config_default_construction():
252+
model_config = ModelConfig(alias="test", model="test")
253+
assert model_config.inference_parameters == InferenceParameters()

0 commit comments

Comments
 (0)