34 changes: 1 addition & 33 deletions src/data_designer/cli/commands/list.py
@@ -6,7 +6,6 @@
from data_designer.cli.repositories.model_repository import ModelRepository
from data_designer.cli.repositories.provider_repository import ProviderRepository
from data_designer.cli.ui import console, print_error, print_header, print_info, print_warning
from data_designer.config.models import ModelConfig
from data_designer.config.utils.constants import DATA_DESIGNER_HOME, NordColor


@@ -76,37 +75,6 @@ def display_providers(provider_repo: ProviderRepository) -> None:
console.print()


def format_inference_parameters(model_config: ModelConfig) -> str:
"""Format inference parameters based on generation type.

Args:
model_config: Model configuration

Returns:
Formatted string of inference parameters
"""
params = model_config.inference_parameters

# Get parameter values as dict, excluding common base parameters
params_dict = params.model_dump(exclude_none=True, mode="json")

if not params_dict:
return "(none)"

# Format each parameter
parts = []
for key, value in params_dict.items():
# Check if value is a distribution (has dict structure with distribution_type)
if isinstance(value, dict) and "distribution_type" in value:
formatted_value = "dist"
elif isinstance(value, float):
formatted_value = f"{value:.2f}"
else:
formatted_value = str(value)
parts.append(f"{key}={formatted_value}")
return ", ".join(parts)


def display_models(model_repo: ModelRepository) -> None:
"""Load and display model configurations.

@@ -132,7 +100,7 @@ def display_models(model_repo: ModelRepository) -> None:
table.add_column("Inference Parameters", style=NordColor.NORD15.value)

for mc in registry.model_configs:
params_display = format_inference_parameters(mc)
params_display = mc.inference_parameters.format_for_display()

table.add_row(
mc.alias,
45 changes: 45 additions & 0 deletions src/data_designer/config/models.py
@@ -239,6 +239,37 @@ def generate_kwargs(self) -> dict[str, Any]:
result["extra_body"] = self.extra_body
return result

def format_for_display(self) -> str:
"""Format inference parameters for display.

Returns:
Formatted string of inference parameters
"""
params_dict = self.model_dump(exclude_none=True, mode="json")

if not params_dict:
return "(none)"

parts = []
for key, value in params_dict.items():
formatted_value = self._format_value(key, value)
parts.append(f"{key}={formatted_value}")
return ", ".join(parts)

def _format_value(self, key: str, value: Any) -> str:
"""Format a single parameter value. Override in subclasses for custom formatting.

Args:
key: Parameter name
value: Parameter value

Returns:
Formatted string representation of the value
"""
if isinstance(value, float):
return f"{value:.2f}"
return str(value)


class ChatCompletionInferenceParams(BaseInferenceParams):
"""Configuration for LLM inference parameters.
@@ -311,6 +342,20 @@ def _run_validation(
def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
return min_value <= value <= max_value

def _format_value(self, key: str, value: Any) -> str:
"""Format chat completion parameter values, including distributions.

Args:
key: Parameter name
value: Parameter value

Returns:
Formatted string representation of the value
"""
if isinstance(value, dict) and "distribution_type" in value:
return "dist"
return super()._format_value(key, value)


# Maintain backwards compatibility with a deprecation warning
class InferenceParameters(ChatCompletionInferenceParams):
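Illustrative usage of the new formatting API, based on the tests added later in this diff (this sketch is not part of the change itself; key ordering follows the model's field order and may differ):

# Sketch: expected behavior of the new format_for_display() method.
# Values and expected substrings are taken from the tests added in this PR.
from data_designer.config.models import ChatCompletionInferenceParams

params = ChatCompletionInferenceParams(temperature=0.7, top_p=0.9, max_tokens=2048)
print(params.format_for_display())
# e.g. "generation_type=chat-completion, temperature=0.70, top_p=0.90, max_tokens=2048, max_parallel_requests=4"
# None-valued fields are excluded, floats are rendered with two decimals, and
# distribution-valued parameters (e.g. a UniformDistribution temperature) render as "dist".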
8 changes: 4 additions & 4 deletions src/data_designer/config/utils/visualization.py
@@ -321,15 +321,15 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
table_model_configs.add_column("Alias")
table_model_configs.add_column("Model")
table_model_configs.add_column("Provider")
table_model_configs.add_column("Temperature")
table_model_configs.add_column("Top P")
table_model_configs.add_column("Inference Parameters")
for model_config in model_configs:
params_display = model_config.inference_parameters.format_for_display()

table_model_configs.add_row(
model_config.alias,
model_config.model,
model_config.provider,
str(model_config.inference_parameters.temperature),
str(model_config.inference_parameters.top_p),
params_display,
)
group_args: list = [Rule(title="Model Configs"), table_model_configs]
if len(model_configs) == 0:
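A small self-contained sketch of the resulting table layout: the dedicated Temperature and Top P columns are replaced by a single column populated via format_for_display(). The model entry below is hypothetical, used only for illustration (rich is already a dependency of this module):

# Sketch only: mimics the updated display_model_configs_table() layout
# with a made-up model entry rather than a real ModelConfig.
from rich.console import Console
from rich.table import Table

table = Table(title="Model Configs")
for column in ("Alias", "Model", "Provider", "Inference Parameters"):
    table.add_column(column)

# Every parameter set now renders through format_for_display() into one column.
table.add_row(
    "my-alias",
    "my-model",
    "my-provider",
    "generation_type=chat-completion, temperature=0.70, top_p=0.90, max_parallel_requests=4",
)
Console().print(table)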
89 changes: 89 additions & 0 deletions tests/config/test_models.py
@@ -296,3 +296,92 @@ def test_model_config_generation_type_from_dict():
)
assert isinstance(model_config.inference_parameters, ChatCompletionInferenceParams)
assert model_config.generation_type == GenerationType.CHAT_COMPLETION


def test_chat_completion_params_format_for_display_all_params():
"""Test formatting chat completion model with all parameters."""
params = ChatCompletionInferenceParams(
temperature=0.7,
top_p=0.9,
max_tokens=2048,
max_parallel_requests=4,
timeout=60,
)
result = params.format_for_display()
assert "generation_type=chat-completion" in result
assert "temperature=0.70" in result
assert "top_p=0.90" in result
assert "max_tokens=2048" in result
assert "max_parallel_requests=4" in result
assert "timeout=60" in result


def test_chat_completion_params_format_for_display_partial_params():
"""Test formatting chat completion model with partial parameters (some None)."""
params = ChatCompletionInferenceParams(
temperature=0.5,
max_tokens=1024,
)
result = params.format_for_display()
assert "generation_type=chat-completion" in result
assert "temperature=0.50" in result
assert "max_tokens=1024" in result
# None values should be excluded
assert "top_p" not in result
assert "timeout" not in result


def test_embedding_params_format_for_display():
"""Test formatting embedding model parameters."""
params = EmbeddingInferenceParams(
encoding_format="float",
dimensions=1024,
max_parallel_requests=8,
)
result = params.format_for_display()
assert "generation_type=embedding" in result
assert "encoding_format=float" in result
assert "dimensions=1024" in result
assert "max_parallel_requests=8" in result
# Chat completion params should not appear
assert "temperature" not in result
assert "top_p" not in result


def test_chat_completion_params_format_for_display_with_distribution():
"""Test formatting parameters with distribution (should show 'dist')."""
params = ChatCompletionInferenceParams(
temperature=UniformDistribution(
distribution_type="uniform",
params=UniformDistributionParams(low=0.5, high=0.9),
),
max_tokens=2048,
)
result = params.format_for_display()
assert "generation_type=chat-completion" in result
assert "temperature=dist" in result
assert "max_tokens=2048" in result


def test_inference_params_format_for_display_float_formatting():
"""Test that float values are formatted to 2 decimal places."""
params = ChatCompletionInferenceParams(
temperature=0.123456,
top_p=0.987654,
)
result = params.format_for_display()
assert "temperature=0.12" in result
assert "top_p=0.99" in result


def test_inference_params_format_for_display_minimal_params():
"""Test formatting with only required parameters."""
params = ChatCompletionInferenceParams()
result = params.format_for_display()
assert "generation_type=chat-completion" in result
assert "max_parallel_requests=4" in result # Default value
# Optional params should not appear when None
assert "temperature" not in result
assert "top_p" not in result
assert "max_tokens" not in result
assert "timeout" not in result
6 changes: 5 additions & 1 deletion tests/config/utils/test_visualization.py
@@ -8,7 +8,11 @@

from data_designer.config.config_builder import DataDesignerConfigBuilder
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.utils.visualization import display_sample_record, get_truncated_list_as_string, mask_api_key
from data_designer.config.utils.visualization import (
display_sample_record,
get_truncated_list_as_string,
mask_api_key,
)
from data_designer.config.validator_params import CodeValidatorParams

