Skip to content

Commit 796e370

Browse files
authored
Fix handling of inference params in info display (#141)
1 parent 8d4c6c1 commit 796e370

File tree

5 files changed

+144
-38
lines changed

5 files changed

+144
-38
lines changed

src/data_designer/cli/commands/list.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from data_designer.cli.repositories.model_repository import ModelRepository
77
from data_designer.cli.repositories.provider_repository import ProviderRepository
88
from data_designer.cli.ui import console, print_error, print_header, print_info, print_warning
9-
from data_designer.config.models import ModelConfig
109
from data_designer.config.utils.constants import DATA_DESIGNER_HOME, NordColor
1110

1211

@@ -76,37 +75,6 @@ def display_providers(provider_repo: ProviderRepository) -> None:
7675
console.print()
7776

7877

79-
def format_inference_parameters(model_config: ModelConfig) -> str:
80-
"""Format inference parameters based on generation type.
81-
82-
Args:
83-
model_config: Model configuration
84-
85-
Returns:
86-
Formatted string of inference parameters
87-
"""
88-
params = model_config.inference_parameters
89-
90-
# Get parameter values as dict, excluding common base parameters
91-
params_dict = params.model_dump(exclude_none=True, mode="json")
92-
93-
if not params_dict:
94-
return "(none)"
95-
96-
# Format each parameter
97-
parts = []
98-
for key, value in params_dict.items():
99-
# Check if value is a distribution (has dict structure with distribution_type)
100-
if isinstance(value, dict) and "distribution_type" in value:
101-
formatted_value = "dist"
102-
elif isinstance(value, float):
103-
formatted_value = f"{value:.2f}"
104-
else:
105-
formatted_value = str(value)
106-
parts.append(f"{key}={formatted_value}")
107-
return ", ".join(parts)
108-
109-
11078
def display_models(model_repo: ModelRepository) -> None:
11179
"""Load and display model configurations.
11280
@@ -132,7 +100,7 @@ def display_models(model_repo: ModelRepository) -> None:
132100
table.add_column("Inference Parameters", style=NordColor.NORD15.value)
133101

134102
for mc in registry.model_configs:
135-
params_display = format_inference_parameters(mc)
103+
params_display = mc.inference_parameters.format_for_display()
136104

137105
table.add_row(
138106
mc.alias,

src/data_designer/config/models.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,37 @@ def generate_kwargs(self) -> dict[str, Any]:
239239
result["extra_body"] = self.extra_body
240240
return result
241241

242+
def format_for_display(self) -> str:
243+
"""Format inference parameters for display.
244+
245+
Returns:
246+
Formatted string of inference parameters
247+
"""
248+
params_dict = self.model_dump(exclude_none=True, mode="json")
249+
250+
if not params_dict:
251+
return "(none)"
252+
253+
parts = []
254+
for key, value in params_dict.items():
255+
formatted_value = self._format_value(key, value)
256+
parts.append(f"{key}={formatted_value}")
257+
return ", ".join(parts)
258+
259+
def _format_value(self, key: str, value: Any) -> str:
260+
"""Format a single parameter value. Override in subclasses for custom formatting.
261+
262+
Args:
263+
key: Parameter name
264+
value: Parameter value
265+
266+
Returns:
267+
Formatted string representation of the value
268+
"""
269+
if isinstance(value, float):
270+
return f"{value:.2f}"
271+
return str(value)
272+
242273

243274
class ChatCompletionInferenceParams(BaseInferenceParams):
244275
"""Configuration for LLM inference parameters.
@@ -311,6 +342,20 @@ def _run_validation(
311342
def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
312343
return min_value <= value <= max_value
313344

345+
def _format_value(self, key: str, value: Any) -> str:
346+
"""Format chat completion parameter values, including distributions.
347+
348+
Args:
349+
key: Parameter name
350+
value: Parameter value
351+
352+
Returns:
353+
Formatted string representation of the value
354+
"""
355+
if isinstance(value, dict) and "distribution_type" in value:
356+
return "dist"
357+
return super()._format_value(key, value)
358+
314359

315360
# Maintain backwards compatibility with a deprecation warning
316361
class InferenceParameters(ChatCompletionInferenceParams):

src/data_designer/config/utils/visualization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,15 +321,15 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
321321
table_model_configs.add_column("Alias")
322322
table_model_configs.add_column("Model")
323323
table_model_configs.add_column("Provider")
324-
table_model_configs.add_column("Temperature")
325-
table_model_configs.add_column("Top P")
324+
table_model_configs.add_column("Inference Parameters")
326325
for model_config in model_configs:
326+
params_display = model_config.inference_parameters.format_for_display()
327+
327328
table_model_configs.add_row(
328329
model_config.alias,
329330
model_config.model,
330331
model_config.provider,
331-
str(model_config.inference_parameters.temperature),
332-
str(model_config.inference_parameters.top_p),
332+
params_display,
333333
)
334334
group_args: list = [Rule(title="Model Configs"), table_model_configs]
335335
if len(model_configs) == 0:

tests/config/test_models.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,92 @@ def test_model_config_generation_type_from_dict():
296296
)
297297
assert isinstance(model_config.inference_parameters, ChatCompletionInferenceParams)
298298
assert model_config.generation_type == GenerationType.CHAT_COMPLETION
299+
300+
301+
def test_chat_completion_params_format_for_display_all_params():
302+
"""Test formatting chat completion model with all parameters."""
303+
params = ChatCompletionInferenceParams(
304+
temperature=0.7,
305+
top_p=0.9,
306+
max_tokens=2048,
307+
max_parallel_requests=4,
308+
timeout=60,
309+
)
310+
result = params.format_for_display()
311+
assert "generation_type=chat-completion" in result
312+
assert "temperature=0.70" in result
313+
assert "top_p=0.90" in result
314+
assert "max_tokens=2048" in result
315+
assert "max_parallel_requests=4" in result
316+
assert "timeout=60" in result
317+
318+
319+
def test_chat_completion_params_format_for_display_partial_params():
320+
"""Test formatting chat completion model with partial parameters (some None)."""
321+
params = ChatCompletionInferenceParams(
322+
temperature=0.5,
323+
max_tokens=1024,
324+
)
325+
result = params.format_for_display()
326+
assert "generation_type=chat-completion" in result
327+
assert "temperature=0.50" in result
328+
assert "max_tokens=1024" in result
329+
# None values should be excluded
330+
assert "top_p" not in result
331+
assert "timeout" not in result
332+
333+
334+
def test_embedding_params_format_for_display():
335+
"""Test formatting embedding model parameters."""
336+
params = EmbeddingInferenceParams(
337+
encoding_format="float",
338+
dimensions=1024,
339+
max_parallel_requests=8,
340+
)
341+
result = params.format_for_display()
342+
assert "generation_type=embedding" in result
343+
assert "encoding_format=float" in result
344+
assert "dimensions=1024" in result
345+
assert "max_parallel_requests=8" in result
346+
# Chat completion params should not appear
347+
assert "temperature" not in result
348+
assert "top_p" not in result
349+
350+
351+
def test_chat_completion_params_format_for_display_with_distribution():
352+
"""Test formatting parameters with distribution (should show 'dist')."""
353+
params = ChatCompletionInferenceParams(
354+
temperature=UniformDistribution(
355+
distribution_type="uniform",
356+
params=UniformDistributionParams(low=0.5, high=0.9),
357+
),
358+
max_tokens=2048,
359+
)
360+
result = params.format_for_display()
361+
assert "generation_type=chat-completion" in result
362+
assert "temperature=dist" in result
363+
assert "max_tokens=2048" in result
364+
365+
366+
def test_inference_params_format_for_display_float_formatting():
367+
"""Test that float values are formatted to 2 decimal places."""
368+
params = ChatCompletionInferenceParams(
369+
temperature=0.123456,
370+
top_p=0.987654,
371+
)
372+
result = params.format_for_display()
373+
assert "temperature=0.12" in result
374+
assert "top_p=0.99" in result
375+
376+
377+
def test_inference_params_format_for_display_minimal_params():
378+
"""Test formatting with only required parameters."""
379+
params = ChatCompletionInferenceParams()
380+
result = params.format_for_display()
381+
assert "generation_type=chat-completion" in result
382+
assert "max_parallel_requests=4" in result # Default value
383+
# Optional params should not appear when None
384+
assert "temperature" not in result
385+
assert "top_p" not in result
386+
assert "max_tokens" not in result
387+
assert "timeout" not in result

tests/config/utils/test_visualization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99
from data_designer.config.config_builder import DataDesignerConfigBuilder
1010
from data_designer.config.utils.code_lang import CodeLang
11-
from data_designer.config.utils.visualization import display_sample_record, get_truncated_list_as_string, mask_api_key
11+
from data_designer.config.utils.visualization import (
12+
display_sample_record,
13+
get_truncated_list_as_string,
14+
mask_api_key,
15+
)
1216
from data_designer.config.validator_params import CodeValidatorParams
1317

1418

0 commit comments

Comments (0)