diff --git a/src/data_designer/cli/commands/list.py b/src/data_designer/cli/commands/list.py index a93b9d31..fb7d51b1 100644 --- a/src/data_designer/cli/commands/list.py +++ b/src/data_designer/cli/commands/list.py @@ -6,7 +6,6 @@ from data_designer.cli.repositories.model_repository import ModelRepository from data_designer.cli.repositories.provider_repository import ProviderRepository from data_designer.cli.ui import console, print_error, print_header, print_info, print_warning -from data_designer.config.models import ModelConfig from data_designer.config.utils.constants import DATA_DESIGNER_HOME, NordColor @@ -76,37 +75,6 @@ def display_providers(provider_repo: ProviderRepository) -> None: console.print() -def format_inference_parameters(model_config: ModelConfig) -> str: - """Format inference parameters based on generation type. - - Args: - model_config: Model configuration - - Returns: - Formatted string of inference parameters - """ - params = model_config.inference_parameters - - # Get parameter values as dict, excluding common base parameters - params_dict = params.model_dump(exclude_none=True, mode="json") - - if not params_dict: - return "(none)" - - # Format each parameter - parts = [] - for key, value in params_dict.items(): - # Check if value is a distribution (has dict structure with distribution_type) - if isinstance(value, dict) and "distribution_type" in value: - formatted_value = "dist" - elif isinstance(value, float): - formatted_value = f"{value:.2f}" - else: - formatted_value = str(value) - parts.append(f"{key}={formatted_value}") - return ", ".join(parts) - - def display_models(model_repo: ModelRepository) -> None: """Load and display model configurations. @@ -132,7 +100,7 @@ def display_models(model_repo: ModelRepository) -> None: table.add_column("Inference Parameters", style=NordColor.NORD15.value) for mc in registry.model_configs: - params_display = format_inference_parameters(mc) + params_display = mc.inference_parameters.format_for_display() table.add_row( mc.alias, diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 9b5ac6d7..bb08fb58 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -239,6 +239,37 @@ def generate_kwargs(self) -> dict[str, Any]: result["extra_body"] = self.extra_body return result + def format_for_display(self) -> str: + """Format inference parameters for display. + + Returns: + Formatted string of inference parameters + """ + params_dict = self.model_dump(exclude_none=True, mode="json") + + if not params_dict: + return "(none)" + + parts = [] + for key, value in params_dict.items(): + formatted_value = self._format_value(key, value) + parts.append(f"{key}={formatted_value}") + return ", ".join(parts) + + def _format_value(self, key: str, value: Any) -> str: + """Format a single parameter value. Override in subclasses for custom formatting. + + Args: + key: Parameter name + value: Parameter value + + Returns: + Formatted string representation of the value + """ + if isinstance(value, float): + return f"{value:.2f}" + return str(value) + class ChatCompletionInferenceParams(BaseInferenceParams): """Configuration for LLM inference parameters. @@ -311,6 +342,20 @@ def _run_validation( def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool: return min_value <= value <= max_value + def _format_value(self, key: str, value: Any) -> str: + """Format chat completion parameter values, including distributions. + + Args: + key: Parameter name + value: Parameter value + + Returns: + Formatted string representation of the value + """ + if isinstance(value, dict) and "distribution_type" in value: + return "dist" + return super()._format_value(key, value) + # Maintain backwards compatibility with a deprecation warning class InferenceParameters(ChatCompletionInferenceParams): diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index c2cca08b..85a230a9 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -321,15 +321,15 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None: table_model_configs.add_column("Alias") table_model_configs.add_column("Model") table_model_configs.add_column("Provider") - table_model_configs.add_column("Temperature") - table_model_configs.add_column("Top P") + table_model_configs.add_column("Inference Parameters") for model_config in model_configs: + params_display = model_config.inference_parameters.format_for_display() + table_model_configs.add_row( model_config.alias, model_config.model, model_config.provider, - str(model_config.inference_parameters.temperature), - str(model_config.inference_parameters.top_p), + params_display, ) group_args: list = [Rule(title="Model Configs"), table_model_configs] if len(model_configs) == 0: diff --git a/tests/config/test_models.py b/tests/config/test_models.py index 6f6618c2..a4a51579 100644 --- a/tests/config/test_models.py +++ b/tests/config/test_models.py @@ -296,3 +296,92 @@ def test_model_config_generation_type_from_dict(): ) assert isinstance(model_config.inference_parameters, ChatCompletionInferenceParams) assert model_config.generation_type == GenerationType.CHAT_COMPLETION + + +def test_chat_completion_params_format_for_display_all_params(): + """Test formatting chat completion model with all parameters.""" + params = ChatCompletionInferenceParams( + temperature=0.7, + top_p=0.9, + max_tokens=2048, + max_parallel_requests=4, + timeout=60, + ) + result = params.format_for_display() + assert "generation_type=chat-completion" in result + assert "temperature=0.70" in result + assert "top_p=0.90" in result + assert "max_tokens=2048" in result + assert "max_parallel_requests=4" in result + assert "timeout=60" in result + + +def test_chat_completion_params_format_for_display_partial_params(): + """Test formatting chat completion model with partial parameters (some None).""" + params = ChatCompletionInferenceParams( + temperature=0.5, + max_tokens=1024, + ) + result = params.format_for_display() + assert "generation_type=chat-completion" in result + assert "temperature=0.50" in result + assert "max_tokens=1024" in result + # None values should be excluded + assert "top_p" not in result + assert "timeout" not in result + + +def test_embedding_params_format_for_display(): + """Test formatting embedding model parameters.""" + params = EmbeddingInferenceParams( + encoding_format="float", + dimensions=1024, + max_parallel_requests=8, + ) + result = params.format_for_display() + assert "generation_type=embedding" in result + assert "encoding_format=float" in result + assert "dimensions=1024" in result + assert "max_parallel_requests=8" in result + # Chat completion params should not appear + assert "temperature" not in result + assert "top_p" not in result + + +def test_chat_completion_params_format_for_display_with_distribution(): + """Test formatting parameters with distribution (should show 'dist').""" + params = ChatCompletionInferenceParams( + temperature=UniformDistribution( + distribution_type="uniform", + params=UniformDistributionParams(low=0.5, high=0.9), + ), + max_tokens=2048, + ) + result = params.format_for_display() + assert "generation_type=chat-completion" in result + assert "temperature=dist" in result + assert "max_tokens=2048" in result + + +def test_inference_params_format_for_display_float_formatting(): + """Test that float values are formatted to 2 decimal places.""" + params = ChatCompletionInferenceParams( + temperature=0.123456, + top_p=0.987654, + ) + result = params.format_for_display() + assert "temperature=0.12" in result + assert "top_p=0.99" in result + + +def test_inference_params_format_for_display_minimal_params(): + """Test formatting with only required parameters.""" + params = ChatCompletionInferenceParams() + result = params.format_for_display() + assert "generation_type=chat-completion" in result + assert "max_parallel_requests=4" in result # Default value + # Optional params should not appear when None + assert "temperature" not in result + assert "top_p" not in result + assert "max_tokens" not in result + assert "timeout" not in result diff --git a/tests/config/utils/test_visualization.py b/tests/config/utils/test_visualization.py index fa55824e..bb77e895 100644 --- a/tests/config/utils/test_visualization.py +++ b/tests/config/utils/test_visualization.py @@ -8,7 +8,11 @@ from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.utils.code_lang import CodeLang -from data_designer.config.utils.visualization import display_sample_record, get_truncated_list_as_string, mask_api_key +from data_designer.config.utils.visualization import ( + display_sample_record, + get_truncated_list_as_string, + mask_api_key, +) from data_designer.config.validator_params import CodeValidatorParams