Skip to content

Commit 60c1aed

Browse files
authored
chore: update default model config settings (#142)
* Update default model config settings.
* Rename "model id" to "model" in the CLI.
* Fix unit tests.
1 parent 796e370 commit 60c1aed

File tree

6 files changed

+50
-57
lines changed

6 files changed

+50
-57
lines changed

src/data_designer/cli/commands/list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def display_models(model_repo: ModelRepository) -> None:
9595
# Display as table
9696
table = Table(title="Model Configurations", border_style=NordColor.NORD8.value)
9797
table.add_column("Alias", style=NordColor.NORD14.value, no_wrap=True)
98-
table.add_column("Model ID", style=NordColor.NORD4.value)
98+
table.add_column("Model", style=NordColor.NORD4.value)
9999
table.add_column("Provider", style=NordColor.NORD9.value, no_wrap=True)
100100
table.add_column("Inference Parameters", style=NordColor.NORD15.value)
101101

src/data_designer/cli/forms/model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
4242
fields.append(
4343
TextField(
4444
"model",
45-
"Model ID",
45+
"Model",
4646
default=initial_data.get("model") if initial_data else None,
4747
required=True,
48-
validator=lambda x: (False, "Model ID is required") if not x else (True, None),
48+
validator=lambda x: (False, "Model is required") if not x else (True, None),
4949
)
5050
)
5151

src/data_designer/config/default_model_settings.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,57 +27,32 @@
2727
logger = logging.getLogger(__name__)
2828

2929

30-
def get_default_text_alias_inference_parameters() -> ChatCompletionInferenceParams:
31-
return ChatCompletionInferenceParams(
32-
temperature=0.85,
33-
top_p=0.95,
34-
)
35-
36-
37-
def get_default_reasoning_alias_inference_parameters() -> ChatCompletionInferenceParams:
38-
return ChatCompletionInferenceParams(
39-
temperature=0.35,
40-
top_p=0.95,
41-
)
42-
43-
44-
def get_default_vision_alias_inference_parameters() -> ChatCompletionInferenceParams:
45-
return ChatCompletionInferenceParams(
46-
temperature=0.85,
47-
top_p=0.95,
48-
)
49-
50-
51-
def get_default_embedding_alias_inference_parameters(provider: str) -> EmbeddingInferenceParams:
52-
args = dict(encoding_format="float")
53-
if provider == "nvidia":
54-
args["extra_body"] = {"input_type": "query"}
55-
return EmbeddingInferenceParams(**args)
56-
57-
5830
def get_default_inference_parameters(
59-
model_alias: Literal["text", "reasoning", "vision", "embedding"], provider: str
31+
model_alias: Literal["text", "reasoning", "vision", "embedding"],
32+
inference_parameters: dict[str, Any],
6033
) -> InferenceParamsT:
6134
if model_alias == "reasoning":
62-
return get_default_reasoning_alias_inference_parameters()
35+
return ChatCompletionInferenceParams(**inference_parameters)
6336
elif model_alias == "vision":
64-
return get_default_vision_alias_inference_parameters()
37+
return ChatCompletionInferenceParams(**inference_parameters)
6538
elif model_alias == "embedding":
66-
return get_default_embedding_alias_inference_parameters(provider)
39+
return EmbeddingInferenceParams(**inference_parameters)
6740
else:
68-
return get_default_text_alias_inference_parameters()
41+
return ChatCompletionInferenceParams(**inference_parameters)
6942

7043

7144
def get_builtin_model_configs() -> list[ModelConfig]:
7245
model_configs = []
7346
for provider, model_alias_map in PREDEFINED_PROVIDERS_MODEL_MAP.items():
74-
for model_alias, model_id in model_alias_map.items():
47+
for model_alias, settings in model_alias_map.items():
7548
model_configs.append(
7649
ModelConfig(
7750
alias=f"{provider}-{model_alias}",
78-
model=model_id,
51+
model=settings["model"],
7952
provider=provider,
80-
inference_parameters=get_default_inference_parameters(model_alias, provider),
53+
inference_parameters=get_default_inference_parameters(
54+
model_alias, settings["inference_parameters"]
55+
),
8156
)
8257
)
8358
return model_configs

src/data_designer/config/utils/constants.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,17 +299,27 @@ class NordColor(Enum):
299299
},
300300
]
301301

302+
303+
DEFAULT_TEXT_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
304+
DEFAULT_REASONING_INFERENCE_PARAMS = {"temperature": 0.35, "top_p": 0.95}
305+
DEFAULT_VISION_INFERENCE_PARAMS = {"temperature": 0.85, "top_p": 0.95}
306+
DEFAULT_EMBEDDING_INFERENCE_PARAMS = {"encoding_format": "float"}
307+
308+
302309
PREDEFINED_PROVIDERS_MODEL_MAP = {
303310
NVIDIA_PROVIDER_NAME: {
304-
"text": "nvidia/nemotron-3-nano-30b-a3b",
305-
"reasoning": "openai/gpt-oss-20b",
306-
"vision": "nvidia/nemotron-nano-12b-v2-vl",
307-
"embedding": "nvidia/llama-3.2-nv-embedqa-1b-v2",
311+
"text": {"model": "nvidia/nemotron-3-nano-30b-a3b", "inference_parameters": {"temperature": 1.0, "top_p": 1.0}},
312+
"reasoning": {"model": "openai/gpt-oss-20b", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
313+
"vision": {"model": "nvidia/nemotron-nano-12b-v2-vl", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
314+
"embedding": {
315+
"model": "nvidia/llama-3.2-nv-embedqa-1b-v2",
316+
"inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS | {"extra_body": {"input_type": "query"}},
317+
},
308318
},
309319
OPENAI_PROVIDER_NAME: {
310-
"text": "gpt-4.1",
311-
"reasoning": "gpt-5",
312-
"vision": "gpt-5",
313-
"embedding": "text-embedding-3-large",
320+
"text": {"model": "gpt-4.1", "inference_parameters": DEFAULT_TEXT_INFERENCE_PARAMS},
321+
"reasoning": {"model": "gpt-5", "inference_parameters": DEFAULT_REASONING_INFERENCE_PARAMS},
322+
"vision": {"model": "gpt-5", "inference_parameters": DEFAULT_VISION_INFERENCE_PARAMS},
323+
"embedding": {"model": "text-embedding-3-large", "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS},
314324
},
315325
}

tests/cli/forms/test_model_builder.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,19 @@ def test_alias_field_accepts_any_alias_when_no_existing() -> None:
5151
assert alias_field.value == "my-model"
5252

5353

54-
# Model ID validation tests
55-
def test_model_id_field_rejects_empty_string() -> None:
54+
# Model validation tests
55+
def test_model_field_rejects_empty_string() -> None:
5656
"""Test model ID field rejects empty strings."""
5757
builder = ModelFormBuilder()
5858
form = builder.create_form()
5959
model_field = form.get_field("model")
6060

61-
with pytest.raises(ValidationError, match="Model ID is required"):
61+
with pytest.raises(ValidationError, match="Model is required"):
6262
model_field.value = ""
6363

6464

65-
def test_model_id_field_accepts_any_non_empty_string() -> None:
66-
"""Test model ID field accepts any non-empty string."""
65+
def test_model_field_accepts_any_non_empty_string() -> None:
66+
"""Test model field accepts any non-empty string."""
6767
builder = ModelFormBuilder()
6868
form = builder.create_form()
6969
model_field = form.get_field("model")

tests/config/test_default_model_settings.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,31 @@
2323

2424

2525
def test_get_default_inference_parameters():
26-
assert get_default_inference_parameters("text", "nvidia") == ChatCompletionInferenceParams(
26+
assert get_default_inference_parameters(
27+
"text", {"temperature": 0.85, "top_p": 0.95}
28+
) == ChatCompletionInferenceParams(
2729
temperature=0.85,
2830
top_p=0.95,
2931
)
30-
assert get_default_inference_parameters("reasoning", "nvidia") == ChatCompletionInferenceParams(
32+
assert get_default_inference_parameters(
33+
"reasoning", {"temperature": 0.35, "top_p": 0.95}
34+
) == ChatCompletionInferenceParams(
3135
temperature=0.35,
3236
top_p=0.95,
3337
)
34-
assert get_default_inference_parameters("vision", "nvidia") == ChatCompletionInferenceParams(
38+
assert get_default_inference_parameters(
39+
"vision", {"temperature": 0.85, "top_p": 0.95}
40+
) == ChatCompletionInferenceParams(
3541
temperature=0.85,
3642
top_p=0.95,
3743
)
38-
assert get_default_inference_parameters("embedding", "nvidia") == EmbeddingInferenceParams(
44+
assert get_default_inference_parameters(
45+
"embedding", {"encoding_format": "float", "extra_body": {"input_type": "query"}}
46+
) == EmbeddingInferenceParams(
3947
encoding_format="float",
4048
extra_body={"input_type": "query"},
4149
)
42-
assert get_default_inference_parameters("embedding", "openai") == EmbeddingInferenceParams(
50+
assert get_default_inference_parameters("embedding", {"encoding_format": "float"}) == EmbeddingInferenceParams(
4351
encoding_format="float",
4452
)
4553

0 commit comments

Comments (0)