diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py index d19b6a9e..a3bef936 100644 --- a/src/data_designer/config/column_configs.py +++ b/src/data_designer/config/column_configs.py @@ -377,3 +377,64 @@ class SeedDatasetColumnConfig(SingleColumnConfig): """ column_type: Literal["seed-dataset"] = "seed-dataset" + + +class EmbeddingColumnConfig(SingleColumnConfig): + """Configuration for embedding generation columns. + + Embedding columns generate embeddings for text input using a specified model. + + Attributes: + column_type: Discriminator field, always "embedding" for this configuration type. + target_column: The column to generate embeddings for. Its value may be a single text string or a list of text strings serialized as JSON; + in the latter case, an embedding is generated for each string in the list. + model_alias: The model to use for embedding generation. + """ + + column_type: Literal["embedding"] = "embedding" + target_column: str + model_alias: str + + @property + def required_columns(self) -> list[str]: + return [self.target_column] + + +class ImageGenerationColumnConfig(SingleColumnConfig): + """Configuration for image generation columns. + + Image columns generate images using a specified model. + + Attributes: + column_type: Discriminator field, always "image-generation" for this configuration type. + prompt: Prompt template for image generation. Supports Jinja2 templating to + reference other columns (e.g., "Generate an image of a {{ character_name }}"). + Must be a valid Jinja2 template. + model_alias: The model to use for image generation. + """ + + column_type: Literal["image-generation"] = "image-generation" + prompt: str + model_alias: str + + @property + def required_columns(self) -> list[str]: + """Get columns referenced in the prompt template. + + Returns: + List of unique column names referenced in Jinja2 templates. + """ + return list(get_prompt_template_keywords(self.prompt)) + + @model_validator(mode="after") + def assert_prompt_valid_jinja(self) -> Self: + """Validate that prompt is a valid Jinja2 template. + + Returns: + The validated instance. + + Raises: + InvalidConfigError: If prompt contains invalid Jinja2 syntax. 
+ """ + assert_valid_jinja2_template(self.prompt) + return self diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index 50ba498d..efdeb094 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -7,7 +7,9 @@ from ..plugin_manager import PluginManager from .column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, + ImageGenerationColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -31,6 +33,8 @@ SamplerColumnConfig, SeedDatasetColumnConfig, ValidationColumnConfig, + EmbeddingColumnConfig, + ImageGenerationColumnConfig, ] ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) @@ -50,6 +54,8 @@ DataDesignerColumnType.SEED_DATASET: "🌱", DataDesignerColumnType.SAMPLER: "🎲", DataDesignerColumnType.VALIDATION: "🔍", + DataDesignerColumnType.EMBEDDING: "🧬", + DataDesignerColumnType.IMAGE_GENERATION: "🖼️", } COLUMN_TYPE_EMOJI_MAP.update( {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()} @@ -66,6 +72,8 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, + DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE_GENERATION, } dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) return column_type in dag_column_types @@ -79,6 +87,8 @@ def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType] DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE_GENERATION, } llm_generated_column_types.update( plugin_manager.get_plugin_column_types( @@ -117,6 +127,10 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) if column_type == DataDesignerColumnType.SEED_DATASET: return SeedDatasetColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.EMBEDDING: + return EmbeddingColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.IMAGE_GENERATION: + return ImageGenerationColumnConfig(name=name, **kwargs) if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value): return plugin.config_cls(name=name, **kwargs) raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover @@ -131,6 +145,8 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE_GENERATION, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] diff --git a/src/data_designer/config/default_model_settings.py b/src/data_designer/config/default_model_settings.py index 32c1d42b..cb565178 100644 --- a/src/data_designer/config/default_model_settings.py +++ b/src/data_designer/config/default_model_settings.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Any, Literal, Optional -from .models import InferenceParameters, ModelConfig, ModelProvider +from .models import CompletionInferenceParameters, ModelConfig, ModelProvider from .utils.constants import ( MANAGED_ASSETS_PATH, MODEL_CONFIGS_FILE_PATH, @@ -21,28 
+21,30 @@ logger = logging.getLogger(__name__) -def get_default_text_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_text_alias_inference_parameters() -> CompletionInferenceParameters: + return CompletionInferenceParameters( temperature=0.85, top_p=0.95, ) -def get_default_reasoning_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_reasoning_alias_inference_parameters() -> CompletionInferenceParameters: + return CompletionInferenceParameters( temperature=0.35, top_p=0.95, ) -def get_default_vision_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_vision_alias_inference_parameters() -> CompletionInferenceParameters: + return CompletionInferenceParameters( temperature=0.85, top_p=0.95, ) -def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters: +def get_default_inference_parameters( + model_alias: Literal["text", "reasoning", "vision"], +) -> CompletionInferenceParameters: if model_alias == "reasoning": return get_default_reasoning_alias_inference_parameters() elif model_alias == "vision": @@ -103,7 +105,8 @@ def resolve_seed_default_model_settings() -> None: f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}" ) save_config_file( - MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]} + MODEL_CONFIGS_FILE_PATH, + {"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]}, ) if not MODEL_PROVIDERS_FILE_PATH.exists(): @@ -111,7 +114,7 @@ def resolve_seed_default_model_settings() -> None: f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}" ) save_config_file( - MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]} + MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]} ) if not MANAGED_ASSETS_PATH.exists(): diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 6bff8efd..4b3ae12c 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -5,7 +5,7 @@ from enum import Enum import logging from pathlib import Path -from typing import Any, Generic, List, Optional, TypeVar, Union +from typing import Any, Generic, List, Literal, Optional, TypeVar, Union import numpy as np from pydantic import BaseModel, Field, model_validator @@ -136,17 +136,29 @@ def sample(self) -> float: DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution] -class InferenceParameters(ConfigBase): - temperature: Optional[Union[float, DistributionT]] = None - top_p: Optional[Union[float, DistributionT]] = None - max_tokens: Optional[int] = Field(default=None, ge=1) +class BaseInferenceParameters(ConfigBase, ABC): max_parallel_requests: int = Field(default=4, ge=1) timeout: Optional[int] = Field(default=None, ge=1) extra_body: Optional[dict[str, Any]] = None @property - def generate_kwargs(self) -> dict[str, Union[float, int]]: + def generate_kwargs(self) -> dict[str, Any]: result = {} + if self.timeout is not None: + result["timeout"] = self.timeout + if self.extra_body is not None and self.extra_body != {}: + result["extra_body"] = self.extra_body + return result + + +class CompletionInferenceParameters(BaseInferenceParameters): + temperature: Optional[Union[float, 
DistributionT]] = None + top_p: Optional[Union[float, DistributionT]] = None + max_tokens: Optional[int] = Field(default=None, ge=1) + + @property + def generate_kwargs(self) -> dict[str, Any]: + result = super().generate_kwargs if self.temperature is not None: result["temperature"] = ( self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature @@ -155,10 +167,6 @@ def generate_kwargs(self) -> dict[str, Union[float, int]]: result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p if self.max_tokens is not None: result["max_tokens"] = self.max_tokens - if self.timeout is not None: - result["timeout"] = self.timeout - if self.extra_body is not None and self.extra_body != {}: - result["extra_body"] = self.extra_body return result @model_validator(mode="after") @@ -205,12 +213,89 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) - return min_value <= value <= max_value + +# Maintain backwards compatibility with a deprecation warning +class InferenceParameters(CompletionInferenceParameters): + """ + Deprecated: Use CompletionInferenceParameters instead. + This alias will be removed in a future version. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + logger.warning( + "InferenceParameters is deprecated and will be removed in a future version. " + "Use CompletionInferenceParameters instead." + ) + super().__init__(*args, **kwargs) + + +class EmbeddingInferenceParameters(BaseInferenceParameters): + encoding_format: Optional[Literal["float", "base64"]] = None + dimensions: Optional[int] = None + + @property + def generate_kwargs(self) -> dict[str, Any]: + result = super().generate_kwargs + if self.encoding_format is not None: + result["encoding_format"] = self.encoding_format + if self.dimensions is not None: + result["dimensions"] = self.dimensions + return result + + +class ImageGenerationInferenceParameters(BaseInferenceParameters): + quality: str + size: str + output_format: Optional[ModalityDataType] = ModalityDataType.BASE64 + + @property + def generate_kwargs(self) -> dict[str, Any]: + result = super().generate_kwargs + result["size"] = self.size + result["quality"] = self.quality + result["response_format"] = "b64_json" if self.output_format == ModalityDataType.BASE64 else self.output_format + return result + + +InferenceParametersT: TypeAlias = Union[ + InferenceParameters, CompletionInferenceParameters, EmbeddingInferenceParameters, ImageGenerationInferenceParameters ] + + +class GenerationType(str, Enum): + CHAT_COMPLETION = "chat-completion" + EMBEDDING = "embedding" + IMAGE_GENERATION = "image-generation" + + class ModelConfig(ConfigBase): alias: str model: str - inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters) + inference_parameters: InferenceParametersT = Field(default_factory=CompletionInferenceParameters) + generation_type: Optional[GenerationType] = Field(default=GenerationType.CHAT_COMPLETION) provider: Optional[str] = None + @model_validator(mode="after") + def _normalize_deprecated_inference_parameters(self) -> Self: + """Normalize deprecated InferenceParameters to CompletionInferenceParameters.""" + if isinstance(self.inference_parameters, InferenceParameters): + self.inference_parameters = CompletionInferenceParameters(**self.inference_parameters.model_dump()) + return self + + @model_validator(mode="after") + def _validate_generation_type(self) -> Self: + generation_type_instance_map = { 
GenerationType.CHAT_COMPLETION: CompletionInferenceParameters, + GenerationType.EMBEDDING: EmbeddingInferenceParameters, + GenerationType.IMAGE_GENERATION: ImageGenerationInferenceParameters, + } + if self.generation_type not in generation_type_instance_map: + raise ValueError(f"Invalid generation type: {self.generation_type}") + if not isinstance(self.inference_parameters, generation_type_instance_map[self.generation_type]): + raise ValueError( + f"Inference parameters must be an instance of {generation_type_instance_map[self.generation_type].__name__!r} when generation_type is {self.generation_type!r}" + ) + return self + class ModelProvider(ConfigBase): name: str diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index 26ab4ad3..0972daf7 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -8,7 +8,7 @@ from functools import cached_property import json import os -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import pandas as pd @@ -171,6 +171,7 @@ def display_sample_record( + config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION) + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT) + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED) + + config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING) ) if len(non_code_columns) > 0: table = Table(title="Generated Columns", **table_kws) @@ -178,6 +179,10 @@ def display_sample_record( table.add_column("Value") for col in non_code_columns: if not col.drop: + if col.column_type == DataDesignerColumnType.EMBEDDING: + record[col.name]["embeddings"] = [ + get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings") + ] table.add_row(col.name, convert_to_row_element(record[col.name])) render_list.append(pad_console_element(table)) @@ -237,6 +242,14 @@ def display_sample_record( console.print(Group(*render_list), markup=False) +def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str: + if len(long_list) > max_items: + truncated_part = long_list[:max_items] + return f"[{', '.join(str(x) for x in truncated_part)} ...]" + else: + return str(long_list) + + def display_sampler_table( sampler_params: dict[SamplerType, ConfigBase], title: Optional[str] = None, diff --git a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py index 120caef4..1b23c0ea 100644 --- a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py +++ b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py @@ -23,7 +23,7 @@ SingleColumnConfig, ValidationColumnConfig, ) -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.engine.column_generators.utils.prompt_renderer import ( PromptType, RecordBasedPromptRenderer, create_response_recipe, diff --git a/src/data_designer/engine/column_generators/generators/base.py b/src/data_designer/engine/column_generators/generators/base.py index f4ddb60c..a98038b3 100644 --- a/src/data_designer/engine/column_generators/generators/base.py +++ b/src/data_designer/engine/column_generators/generators/base.py @@ -2,12 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +import functools +import logging from typing import overload import 
pandas as pd +from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP +from data_designer.config.models import BaseInferenceParameters, ModelConfig from data_designer.config.utils.type_helpers import StrEnum from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT +from data_designer.engine.models.facade import ModelFacade + +logger = logging.getLogger(__name__) class GenerationStrategy(StrEnum): @@ -59,3 +66,30 @@ def can_generate_from_scratch(self) -> bool: @abstractmethod def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ... + + +class WithModelGeneration: + @functools.cached_property + def model(self) -> ModelFacade: + return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias) + + @functools.cached_property + def model_config(self) -> ModelConfig: + return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias) + + @functools.cached_property + def inference_parameters(self) -> BaseInferenceParameters: + return self.model_config.inference_parameters + + def log_pre_generation(self) -> None: + emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] + logger.info(f"{emoji} Preparing {self.config.column_type} column generation") + logger.info(f" |-- column name: {self.config.name!r}") + logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") + if self.model_config.provider is None: + logger.info(f" |-- default model provider: {self._get_provider_name()!r}") + + def _get_provider_name(self) -> str: + model_alias = self.model_config.alias + provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) + return provider.name diff --git a/src/data_designer/engine/column_generators/generators/embedding.py b/src/data_designer/engine/column_generators/generators/embedding.py new file mode 100644 index 00000000..ed738e8f --- /dev/null +++ b/src/data_designer/engine/column_generators/generators/embedding.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
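# --- Illustrative sketch (editor's addition, not part of the diff) ---
# How the new embedding pieces fit together: a ModelConfig whose generation_type
# is EMBEDDING must carry EmbeddingInferenceParameters (enforced by the
# _validate_generation_type validator in models.py earlier in this diff), and an
# EmbeddingColumnConfig points at the text column to embed. The alias and model
# name below are hypothetical.
from data_designer.config.column_configs import EmbeddingColumnConfig
from data_designer.config.models import (
    EmbeddingInferenceParameters,
    GenerationType,
    ModelConfig,
)

embedding_model = ModelConfig(
    alias="embedder",
    model="openai/text-embedding-3-small",  # hypothetical model id
    generation_type=GenerationType.EMBEDDING,
    inference_parameters=EmbeddingInferenceParameters(dimensions=256),
)

embedding_column = EmbeddingColumnConfig(
    name="description_embeddings",
    target_column="descriptions",  # single string or JSON-serialized list of strings
    model_alias="embedder",
)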
+# SPDX-License-Identifier: Apache-2.0 + + +from data_designer.config.column_configs import EmbeddingColumnConfig +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + GeneratorMetadata, + WithModelGeneration, +) +from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string +from data_designer.engine.resources.resource_provider import ResourceType + + +class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="embedding_cell_generator", + description="Generate embeddings for a text column.", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=[ResourceType.MODEL_REGISTRY], + ) + + def generate(self, data: dict) -> dict: + deserialized_record = deserialize_json_values(data) + input_texts = parse_list_string(deserialized_record[self.config.target_column]) + embeddings = self.model.generate_text_embeddings(input_texts=input_texts) + + data[self.config.name] = { + "embeddings": embeddings, + "num_embeddings": len(embeddings), + "dimension": len(embeddings[0]) if len(embeddings) > 0 else 0, + } + return data diff --git a/src/data_designer/engine/column_generators/generators/image.py b/src/data_designer/engine/column_generators/generators/image.py new file mode 100644 index 00000000..f7cfba89 --- /dev/null +++ b/src/data_designer/engine/column_generators/generators/image.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from litellm.types.utils import ImageResponse + +from data_designer.config.column_configs import ImageGenerationColumnConfig +from data_designer.config.models import ModalityDataType +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + GeneratorMetadata, + WithModelGeneration, +) +from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering +from data_designer.engine.processing.utils import deserialize_json_values +from data_designer.engine.resources.resource_provider import ResourceType + + +class ImageCellGenerator( + WithModelGeneration, WithJinja2UserTemplateRendering, ColumnGenerator[ImageGenerationColumnConfig] +): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="image_cell_generator", + description="Generate images using a specified model.", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=[ResourceType.MODEL_REGISTRY], + ) + + def generate(self, data: dict) -> dict: + deserialized_record = deserialize_json_values(data) + missing_columns = list(set(self.config.required_columns) - set(data.keys())) + if len(missing_columns) > 0: + error_msg = ( + f"There was an error preparing the Jinja2 prompt template. " + f"The following columns {missing_columns} are missing!" 
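# --- Illustrative sketch (editor's addition, not part of the diff) ---
# Shape of the value EmbeddingCellGenerator.generate() (above) writes into the
# record: one embedding per parsed input string. The 4-dimensional vectors
# here are made-up placeholder values.
record_after_generation = {
    "descriptions": '["red shirt", "blue shirt"]',  # target column, JSON list
    "description_embeddings": {
        "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
        "num_embeddings": 2,
        "dimension": 4,
    },
}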
+ ) + raise ValueError(error_msg) + + self.prepare_jinja2_template_renderer(self.config.prompt, list(deserialized_record.keys())) + prompt = self.render_template(deserialized_record) + image_response: ImageResponse = self.model.generate_image(prompt=prompt) + if self.model_config.inference_parameters.output_format == ModalityDataType.URL: + data[self.config.name] = image_response.data[0].url + else: + data[self.config.name] = image_response.data[0].b64_json + return data diff --git a/src/data_designer/engine/column_generators/generators/llm_generators.py b/src/data_designer/engine/column_generators/generators/llm_completion.py similarity index 70% rename from src/data_designer/engine/column_generators/generators/llm_generators.py rename to src/data_designer/engine/column_generators/generators/llm_completion.py index ee0ab58a..8fae174b 100644 --- a/src/data_designer/engine/column_generators/generators/llm_generators.py +++ b/src/data_designer/engine/column_generators/generators/llm_completion.py @@ -10,43 +10,41 @@ LLMStructuredColumnConfig, LLMTextColumnConfig, ) -from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP -from data_designer.config.models import InferenceParameters, ModelConfig from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, GenerationStrategy, GeneratorMetadata, + WithModelGeneration, ) from data_designer.engine.column_generators.utils.prompt_renderer import ( PromptType, RecordBasedPromptRenderer, create_response_recipe, ) -from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.recipes.base import ResponseRecipe from data_designer.engine.processing.utils import deserialize_json_values from data_designer.engine.resources.resource_provider import ResourceType -DEFAULT_MAX_CONVERSATION_RESTARTS = 5 -DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0 +logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +DEFAULT_MAX_CONVERSATION_RESTARTS = 5 +DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0 -class WithLLMGeneration: +class WithCompletionGeneration(WithModelGeneration): @functools.cached_property - def model(self) -> ModelFacade: - return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias) + def response_recipe(self) -> ResponseRecipe: + return create_response_recipe(self.config, self.model_config) - @functools.cached_property - def model_config(self) -> ModelConfig: - return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias) + @property + def max_conversation_correction_steps(self) -> int: + return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS - @functools.cached_property - def inference_parameters(self) -> InferenceParameters: - return self.model_config.inference_parameters + @property + def max_conversation_restarts(self) -> int: + return DEFAULT_MAX_CONVERSATION_RESTARTS @functools.cached_property def prompt_renderer(self) -> RecordBasedPromptRenderer: @@ -59,18 +57,6 @@ def prompt_renderer(self) -> RecordBasedPromptRenderer: }, ) - @functools.cached_property - def response_recipe(self) -> ResponseRecipe: - return create_response_recipe(self.config, self.model_config) - - @property - def max_conversation_correction_steps(self) -> int: - return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS - - @property - def max_conversation_restarts(self) -> int: - return DEFAULT_MAX_CONVERSATION_RESTARTS - def generate(self, data: dict) -> dict: 
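# --- Illustrative sketch (editor's addition, not part of the diff) ---
# Pairing an image-generation model with an ImageGenerationColumnConfig, as
# consumed by ImageCellGenerator above. The prompt is a Jinja2 template
# validated at construction time; the alias and model name are hypothetical.
from data_designer.config.column_configs import ImageGenerationColumnConfig
from data_designer.config.models import (
    GenerationType,
    ImageGenerationInferenceParameters,
    ModelConfig,
)

image_model = ModelConfig(
    alias="image-gen",
    model="openai/dall-e-3",  # hypothetical model id
    generation_type=GenerationType.IMAGE_GENERATION,
    inference_parameters=ImageGenerationInferenceParameters(size="1024x1024", quality="standard"),
)

portrait_column = ImageGenerationColumnConfig(
    name="portrait",
    prompt="A portrait of {{ character_name }} in watercolor style",
    model_alias="image-gen",
)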
deserialized_record = deserialize_json_values(data) @@ -96,7 +82,6 @@ def generate(self, data: dict) -> dict: max_correction_steps=self.max_conversation_correction_steps, max_conversation_restarts=self.max_conversation_restarts, purpose=f"running generation for column '{self.config.name}'", - **self.inference_parameters.generate_kwargs, ) data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response)) @@ -106,21 +91,8 @@ def generate(self, data: dict) -> dict: return data - def log_pre_generation(self) -> None: - emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] - logger.info(f"{emoji} Preparing {self.config.column_type} column generation") - logger.info(f" |-- column name: {self.config.name!r}") - logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") - if self.model_config.provider is None: - logger.info(f" |-- default model provider: {self._get_provider_name()!r}") - - def _get_provider_name(self) -> str: - model_alias = self.model_config.alias - provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) - return provider.name - -class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]): +class LLMTextCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( @@ -131,7 +103,7 @@ def metadata() -> GeneratorMetadata: ) -class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]): +class LLMCodeCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( @@ -142,7 +114,7 @@ def metadata() -> GeneratorMetadata: ) -class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]): +class LLMStructuredCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( @@ -153,7 +125,7 @@ def metadata() -> GeneratorMetadata: ) -class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]): +class LLMJudgeCellGenerator(WithCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( @@ -163,10 +135,6 @@ def metadata() -> GeneratorMetadata: required_resources=[ResourceType.MODEL_REGISTRY], ) - @property - def max_conversation_correction_steps(self) -> int: - return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS - @property def max_conversation_restarts(self) -> int: - return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS + return DEFAULT_MAX_CONVERSATION_RESTARTS * 2 diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 61b43753..3d000729 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -3,7 +3,9 @@ from data_designer.config.base import ConfigBase from data_designer.config.column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, + ImageGenerationColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -12,8 +14,10 @@ ) from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.base import ColumnGenerator +from data_designer.engine.column_generators.generators.embedding import 
EmbeddingCellGenerator from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.engine.column_generators.generators.image import ImageCellGenerator +from data_designer.engine.column_generators.generators.llm_completion import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, @@ -40,11 +44,12 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig) registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig) registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig) + registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig) registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig) registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig) - + registry.register(DataDesignerColumnType.IMAGE_GENERATION, ImageCellGenerator, ImageGenerationColumnConfig) if with_plugins: for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): registry.register( diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index e7060f82..2e30407c 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -17,8 +17,11 @@ ProcessorConfig, ProcessorType, ) -from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy -from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + WithModelGeneration, +) from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.multi_column_configs import ( @@ -169,7 +172,7 @@ def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR - if isinstance(generator, WithLLMGeneration): + if isinstance(generator, WithModelGeneration): max_workers = generator.inference_parameters.max_parallel_requests self._fan_out_with_threads(generator, max_workers=max_workers) @@ -183,7 +186,7 @@ def _run_model_health_check_if_needed(self) -> bool: set(config.model_alias for config in self.llm_generated_column_configs) ) - def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None: + def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None: if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} " diff --git 
a/src/data_designer/engine/models/facade.py b/src/data_designer/engine/models/facade.py index 93ca0fd7..33c79797 100644 --- a/src/data_designer/engine/models/facade.py +++ b/src/data_designer/engine/models/facade.py @@ -9,9 +9,9 @@ from typing import Any from litellm.types.router import DeploymentTypedDict, LiteLLM_Params -from litellm.types.utils import ModelResponse +from litellm.types.utils import EmbeddingResponse, ImageResponse, ImageUsage, ModelResponse -from data_designer.config.models import ModelConfig, ModelProvider +from data_designer.config.models import GenerationType, ModelConfig, ModelProvider from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.errors import ( GenerationValidationFailureError, @@ -49,6 +49,10 @@ def model_name(self) -> str: def model_provider(self) -> ModelProvider: return self._model_provider_registry.get_provider(self._model_config.provider) + @property + def model_generation_type(self) -> GenerationType: + return self._model_config.generation_type + @property def model_provider_name(self) -> str: return self.model_provider.name @@ -64,13 +68,12 @@ def usage_stats(self) -> ModelUsageStats: def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse: logger.debug( f"Prompting model {self.model_name!r}...", - extra={"model": self.model_name, "messages": messages, "sensitive": True}, + extra={"model": self.model_name, "messages": messages}, ) response = None - if self.model_provider.extra_body: - kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + kwargs = self.consolidate_kwargs(**kwargs) try: - response = self._router.completion(self.model_name, messages, **kwargs) + response = self._router.completion(model=self.model_name, messages=messages, **kwargs) logger.debug( f"Received completion from model {self.model_name!r}", extra={ @@ -84,9 +87,71 @@ except Exception as e: raise e finally: - if not skip_usage_tracking: + if not skip_usage_tracking and response is not None: self._track_usage(response) + def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: + # Remove purpose from kwargs to avoid passing it to the model + kwargs.pop("purpose", None) + kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} + if self.model_provider.extra_body: + kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + return kwargs + + @catch_llm_exceptions + def generate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response = None + try: + response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.data) if response.data else 0, + "usage": self._usage_stats.model_dump(), + }, + ) + if response.data and len(response.data) == len(input_texts): + return [data["embedding"] for data in response.data] + else: + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data) if response.data else 0}") + except Exception as e: + raise e + finally: + if not 
skip_usage_tracking and response is not None: + self._track_usage_from_embedding(response) + + @catch_llm_exceptions + def generate_image(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> ImageResponse: + logger.debug( + f"Generating image with model {self.model_name!r}...", + extra={"model": self.model_name, "prompt": prompt}, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response = None + try: + response = self._router.image_generation(prompt=prompt, model=self.model_name, **kwargs) + logger.debug( + f"Received image from model {self.model_name!r}", + extra={"model": self.model_name, "response": response}, + ) + return response + except Exception as e: + raise e + finally: + if not skip_usage_tracking and response is not None: + self._track_usage_from_image(response) + @catch_llm_exceptions def generate( self, @@ -223,3 +288,29 @@ def _track_usage(self, response: ModelResponse | None) -> None: ), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) + + def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None: + if response is None: + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return + if response.usage is not None and response.usage.prompt_tokens is not None: + self._usage_stats.extend( + token_usage=TokenUsageStats( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=0, + ), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) + + def _track_usage_from_image(self, response: ImageResponse | None) -> None: + if response is None: + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return + if response.usage is not None and isinstance(response.usage, ImageUsage): + self._usage_stats.extend( + token_usage=TokenUsageStats( + prompt_tokens=response.usage.input_tokens, + completion_tokens=response.usage.output_tokens, + ), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) diff --git a/src/data_designer/engine/models/registry.py b/src/data_designer/engine/models/registry.py index aafd8c80..91025684 100644 --- a/src/data_designer/engine/models/registry.py +++ b/src/data_designer/engine/models/registry.py @@ -5,7 +5,7 @@ import logging -from data_designer.config.models import ModelConfig +from data_designer.config.models import GenerationType, ModelConfig from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.litellm_overrides import apply_litellm_patches @@ -81,15 +81,30 @@ def run_health_check(self, model_aliases: set[str]) -> None: f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..." 
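# --- Illustrative sketch (editor's addition, not part of the diff) ---
# Merge order implemented by ModelFacade.consolidate_kwargs above: the model
# config's inference-parameter defaults apply first, per-call kwargs override
# them, and the provider's extra_body wins over call-level extra_body.
# Values here are hypothetical.
defaults = {"temperature": 0.85, "top_p": 0.95}  # inference_parameters.generate_kwargs
call_kwargs = {"temperature": 0.2, "extra_body": {"a": 1}}
provider_extra_body = {"b": 2}

merged = {**defaults, **call_kwargs}
merged["extra_body"] = {**merged.get("extra_body", {}), **provider_extra_body}
assert merged == {"temperature": 0.2, "top_p": 0.95, "extra_body": {"a": 1, "b": 2}}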
) try: - model.generate( - prompt="Hello!", - parser=lambda x: x, - system_prompt="You are a helpful assistant.", - max_correction_steps=0, - max_conversation_restarts=0, - skip_usage_tracking=True, - purpose="running health checks", - ) + if model.model_generation_type == GenerationType.EMBEDDING: + model.generate_text_embeddings( + input_texts=["Hello!"], + skip_usage_tracking=True, + purpose="running health checks", + ) + elif model.model_generation_type == GenerationType.CHAT_COMPLETION: + model.generate( + prompt="Hello!", + parser=lambda x: x, + system_prompt="You are a helpful assistant.", + max_correction_steps=0, + max_conversation_restarts=0, + skip_usage_tracking=True, + purpose="running health checks", + ) + elif model.model_generation_type == GenerationType.IMAGE_GENERATION: + model.generate_image( + prompt="Generate a simple pixel", + skip_usage_tracking=True, + purpose="running health checks", + ) + else: + raise ValueError(f"Unsupported generation type: {model.model_generation_type}") logger.info(" |-- ✅ Passed!") except Exception as e: logger.error(" |-- ❌ Failed!") diff --git a/src/data_designer/engine/processing/utils.py b/src/data_designer/engine/processing/utils.py index 3579b3bd..5d42c40e 100644 --- a/src/data_designer/engine/processing/utils.py +++ b/src/data_designer/engine/processing/utils.py @@ -1,8 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import ast import json import logging +import re from typing import Any, TypeVar, Union, overload import pandas as pd @@ -100,6 +102,46 @@ def deserialize_json_values(data): return data +def parse_list_string(text: str) -> list[str]: + """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas.""" + text = text.strip() + + # Try JSON first + try: + list_obj = json.loads(text) + if isinstance(list_obj, list): + return _clean_whitespace(list_obj) + except json.JSONDecodeError: + pass + + # Remove trailing commas before closing brackets (common in JSON-like strings) + text_cleaned = re.sub(r",\s*]", "]", text) + text_cleaned = re.sub(r",\s*}", "}", text_cleaned) + + # Try JSON again with cleaned text + try: + list_obj = json.loads(text_cleaned) + if isinstance(list_obj, list): + return _clean_whitespace(list_obj) + except json.JSONDecodeError: + pass + + # Try Python literal eval (handles single quotes) + try: + list_obj = ast.literal_eval(text_cleaned) + if isinstance(list_obj, list): + return _clean_whitespace(list_obj) + except (ValueError, SyntaxError): + pass + + # If all else fails, return the original text wrapped in a list + return [text.strip()] + + +def _clean_whitespace(texts: list[str]) -> list[str]: + return [text.strip() for text in texts] + + def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None: joined_columns = set() for df in datasets: diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 8cd8eb92..e8c6091c 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -6,7 +6,9 @@ from ..config.analysis.column_profilers import JudgeScoreProfilerConfig from ..config.column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, + ImageGenerationColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -22,8 +24,12 @@ from ..config.dataset_builders import BuildStage from ..config.datastore import DatastoreSettings from ..config.models import ( + CompletionInferenceParameters, + EmbeddingInferenceParameters, + GenerationType, ImageContext, ImageFormat, + 
ImageGenerationInferenceParameters, InferenceParameters, ManualDistribution, ManualDistributionParams, @@ -78,25 +84,31 @@ "BernoulliMixtureSamplerParams", "BernoulliSamplerParams", "BinomialSamplerParams", + "BuildStage", "CategorySamplerParams", "CodeLang", "CodeValidatorParams", "ColumnInequalityConstraint", + "CompletionInferenceParameters", "configure_logging", "DataDesignerColumnType", "DataDesignerConfig", "DataDesignerConfigBuilder", - "BuildStage", "DatastoreSeedDatasetReference", "DatastoreSettings", "DatetimeSamplerParams", "DropColumnsProcessorConfig", + "EmbeddingColumnConfig", + "EmbeddingInferenceParameters", "ExpressionColumnConfig", "GaussianSamplerParams", + "GenerationType", "IndexRange", "InfoType", "ImageContext", "ImageFormat", + "ImageGenerationColumnConfig", + "ImageGenerationInferenceParameters", "InferenceParameters", "JudgeScoreProfilerConfig", "LLMCodeColumnConfig", diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 758e837e..66a06347 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -9,16 +9,16 @@ from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository from data_designer.cli.services.model_service import ModelService from data_designer.cli.services.provider_service import ProviderService -from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider +from data_designer.config.models import CompletionInferenceParameters, ModelConfig, ModelProvider @pytest.fixture -def stub_inference_parameters() -> InferenceParameters: - return InferenceParameters(temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4) +def stub_inference_parameters() -> CompletionInferenceParameters: + return CompletionInferenceParameters(temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4) @pytest.fixture -def stub_model_configs(stub_inference_parameters: InferenceParameters) -> list[ModelConfig]: +def stub_model_configs(stub_inference_parameters: CompletionInferenceParameters) -> list[ModelConfig]: return [ ModelConfig( alias="test-alias-1", @@ -41,7 +41,7 @@ def stub_new_model_config() -> ModelConfig: alias="test-alias-3", model="test-model-3", provider="test-provider-1", - inference_parameters=InferenceParameters( + inference_parameters=CompletionInferenceParameters( temperature=0.7, top_p=0.9, max_tokens=2048, diff --git a/tests/cli/controllers/test_model_controller.py b/tests/cli/controllers/test_model_controller.py index b630b04a..4f718ca4 100644 --- a/tests/cli/controllers/test_model_controller.py +++ b/tests/cli/controllers/test_model_controller.py @@ -9,7 +9,7 @@ from data_designer.cli.controllers.model_controller import ModelController from data_designer.cli.repositories.model_repository import ModelConfigRegistry from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import CompletionInferenceParameters, ModelConfig @pytest.fixture @@ -141,7 +141,7 @@ def test_run_updates_model( alias="test-alias-1-updated", model="test-model-1-updated", provider="test-provider-1", - inference_parameters=InferenceParameters(temperature=0.8, top_p=0.95, max_tokens=1024), + inference_parameters=CompletionInferenceParameters(temperature=0.8, top_p=0.95, max_tokens=1024), ) mock_builder = MagicMock() diff --git a/tests/cli/repositories/test_model_repository.py b/tests/cli/repositories/test_model_repository.py index 
01884b5c..624cd360 100644 --- a/tests/cli/repositories/test_model_repository.py +++ b/tests/cli/repositories/test_model_repository.py @@ -21,7 +21,9 @@ def test_load_does_not_exist(): def test_load_exists(tmp_path: Path, stub_model_configs: list[ModelConfig]): model_configs_file_path = tmp_path / MODEL_CONFIGS_FILE_NAME - save_config_file(model_configs_file_path, {"model_configs": [mc.model_dump() for mc in stub_model_configs]}) + save_config_file( + model_configs_file_path, {"model_configs": [mc.model_dump(mode="json") for mc in stub_model_configs]} + ) repository = ModelRepository(tmp_path) assert repository.load() is not None assert repository.load().model_configs == stub_model_configs diff --git a/tests/cli/services/test_model_service.py b/tests/cli/services/test_model_service.py index 1d9bf5aa..4287eee8 100644 --- a/tests/cli/services/test_model_service.py +++ b/tests/cli/services/test_model_service.py @@ -7,7 +7,7 @@ from data_designer.cli.repositories.model_repository import ModelRepository from data_designer.cli.services.model_service import ModelService -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import CompletionInferenceParameters, ModelConfig def test_list_all(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]): @@ -30,7 +30,9 @@ def test_add( assert stub_model_service.list_all() == stub_model_configs + [stub_new_model_config] -def test_add_duplicate_alias(stub_model_service: ModelService, stub_inference_parameters: InferenceParameters): +def test_add_duplicate_alias( + stub_model_service: ModelService, stub_inference_parameters: CompletionInferenceParameters +): """Test adding a model with an alias that already exists.""" duplicate_model = ModelConfig( alias="test-alias-1", @@ -61,7 +63,9 @@ def test_update_nonexistent_model(stub_model_service: ModelService, stub_new_mod stub_model_service.update("nonexistent", stub_new_model_config) -def test_update_to_existing_alias(stub_model_service: ModelService, stub_inference_parameters: InferenceParameters): +def test_update_to_existing_alias( + stub_model_service: ModelService, stub_inference_parameters: CompletionInferenceParameters +): """Test updating a model to an alias that already exists.""" updated_model = ModelConfig( alias="test-alias-2", # Already exists diff --git a/tests/config/test_columns.py b/tests/config/test_columns.py index f0f5c51a..2e74695f 100644 --- a/tests/config/test_columns.py +++ b/tests/config/test_columns.py @@ -49,6 +49,8 @@ def test_data_designer_column_type_get_display_order(): DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, + DataDesignerColumnType.IMAGE_GENERATION, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] @@ -59,6 +61,8 @@ def test_data_designer_column_type_is_llm_generated(): assert column_type_is_llm_generated(DataDesignerColumnType.LLM_CODE) assert column_type_is_llm_generated(DataDesignerColumnType.LLM_STRUCTURED) assert column_type_is_llm_generated(DataDesignerColumnType.LLM_JUDGE) + assert column_type_is_llm_generated(DataDesignerColumnType.EMBEDDING) + assert column_type_is_llm_generated(DataDesignerColumnType.IMAGE_GENERATION) assert not column_type_is_llm_generated(DataDesignerColumnType.SAMPLER) assert not column_type_is_llm_generated(DataDesignerColumnType.VALIDATION) assert not column_type_is_llm_generated(DataDesignerColumnType.EXPRESSION) @@ -72,6 +76,8 @@ def 
test_data_designer_column_type_is_in_dag(): assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_STRUCTURED) assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) + assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING) + assert column_type_used_in_execution_dag(DataDesignerColumnType.IMAGE_GENERATION) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index 337d934e..57741e59 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -26,7 +26,7 @@ from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import CompletionInferenceParameters, ModelConfig from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams from data_designer.config.seed import DatastoreSeedDatasetReference, SamplingStrategy @@ -54,7 +54,7 @@ def stub_data_designer_builder(stub_data_designer_builder_config_str): def test_loading_model_configs_in_constructor(stub_model_configs): - stub_model_configs_dict = [mc.model_dump() for mc in stub_model_configs] + stub_model_configs_dict = [mc.model_dump(mode="json") for mc in stub_model_configs] # test loading model configs from a list builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) assert builder.model_configs == stub_model_configs @@ -670,7 +670,7 @@ def test_add_model_config(stub_empty_builder): new_model_config = ModelConfig( alias="new-model", model="openai/gpt-4", - inference_parameters=InferenceParameters( + inference_parameters=CompletionInferenceParameters( temperature=0.7, top_p=0.95, max_tokens=1024, @@ -691,7 +691,7 @@ def test_add_model_config(stub_empty_builder): alias="provider-model", model="anthropic/claude-3", provider="anthropic", - inference_parameters=InferenceParameters(temperature=0.8), + inference_parameters=CompletionInferenceParameters(temperature=0.8), ) stub_empty_builder.add_model_config(provider_model_config) @@ -717,7 +717,7 @@ def test_add_model_config_duplicate_alias(stub_empty_builder): duplicate_model_config = ModelConfig( alias="stub-model", model="different/model", - inference_parameters=InferenceParameters(temperature=0.5), + inference_parameters=CompletionInferenceParameters(temperature=0.5), ) with pytest.raises( @@ -733,12 +733,12 @@ def test_delete_model_config(stub_empty_builder): model_config_1 = ModelConfig( alias="model-to-delete", model="model/delete", - inference_parameters=InferenceParameters(temperature=0.5), + inference_parameters=CompletionInferenceParameters(temperature=0.5), ) model_config_2 = ModelConfig( alias="model-to-keep", model="model/keep", - inference_parameters=InferenceParameters(temperature=0.6), + inference_parameters=CompletionInferenceParameters(temperature=0.6), ) stub_empty_builder.add_model_config(model_config_1) stub_empty_builder.add_model_config(model_config_2) diff --git 
a/tests/config/test_default_model_settings.py b/tests/config/test_default_model_settings.py index 222bb410..8f389a69 100644 --- a/tests/config/test_default_model_settings.py +++ b/tests/config/test_default_model_settings.py @@ -18,20 +18,20 @@ get_default_providers, resolve_seed_default_model_settings, ) -from data_designer.config.models import InferenceParameters +from data_designer.config.models import CompletionInferenceParameters from data_designer.config.utils.visualization import get_nvidia_api_key, get_openai_api_key def test_get_default_inference_parameters(): - assert get_default_inference_parameters("text") == InferenceParameters( + assert get_default_inference_parameters("text") == CompletionInferenceParameters( temperature=0.85, top_p=0.95, ) - assert get_default_inference_parameters("reasoning") == InferenceParameters( + assert get_default_inference_parameters("reasoning") == CompletionInferenceParameters( temperature=0.35, top_p=0.95, ) - assert get_default_inference_parameters("vision") == InferenceParameters( + assert get_default_inference_parameters("vision") == CompletionInferenceParameters( temperature=0.85, top_p=0.95, ) diff --git a/tests/config/test_models.py b/tests/config/test_models.py index 9ccda6d5..40f6afe9 100644 --- a/tests/config/test_models.py +++ b/tests/config/test_models.py @@ -11,9 +11,12 @@ from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ( + CompletionInferenceParameters, + EmbeddingInferenceParameters, + GenerationType, ImageContext, ImageFormat, - InferenceParameters, + ImageGenerationInferenceParameters, ManualDistribution, ManualDistributionParams, ModalityDataType, @@ -46,13 +49,13 @@ def test_image_context_validate_image_format(): def test_inference_parameters_default_construction(): - empty_inference_parameters = InferenceParameters() + empty_inference_parameters = CompletionInferenceParameters() assert empty_inference_parameters.generate_kwargs == {} assert empty_inference_parameters.max_parallel_requests == 4 def test_inference_parameters_generate_kwargs(): - assert InferenceParameters( + assert CompletionInferenceParameters( temperature=0.95, top_p=0.95, max_tokens=100, @@ -67,9 +70,9 @@ def test_inference_parameters_generate_kwargs(): "extra_body": {"reasoning_effort": "high"}, } - assert InferenceParameters().generate_kwargs == {} + assert CompletionInferenceParameters().generate_kwargs == {} - inference_parameters_kwargs = InferenceParameters( + inference_parameters_kwargs = CompletionInferenceParameters( temperature=UniformDistribution(params=UniformDistributionParams(low=0.0, high=1.0)), top_p=ManualDistribution(params=ManualDistributionParams(values=[0.0, 1.0], weights=[0.5, 0.5])), ).generate_kwargs @@ -131,32 +134,38 @@ def test_inference_parameters_temperature_validation(): # All temp values provide in a manual destribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters( + CompletionInferenceParameters( temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.5], weights=[0.5, 0.5])) ) # High and low values of uniform distribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5))) + CompletionInferenceParameters( + temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5)) + ) with pytest.raises(ValidationError, match=expected_error_msg): - 
InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0))) + CompletionInferenceParameters( + temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0)) + ) # Static values should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=3.0) + CompletionInferenceParameters(temperature=3.0) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=-1.0) + CompletionInferenceParameters(temperature=-1.0) # Valid temperature values shouldn't raise validation errors try: - InferenceParameters(temperature=0.1) - InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0))) - InferenceParameters( + CompletionInferenceParameters(temperature=0.1) + CompletionInferenceParameters( + temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0)) + ) + CompletionInferenceParameters( temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.0], weights=[0.5, 0.5])) ) except Exception: - pytest.fail("Unexpected exception raised during InferenceParameters temperature validation") + pytest.fail("Unexpected exception raised during CompletionInferenceParameters temperature validation") def test_generation_parameters_top_p_validation(): @@ -164,31 +173,31 @@ def test_generation_parameters_top_p_validation(): # All top_p values provide in a manual destribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters( + CompletionInferenceParameters( top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.5], weights=[0.5, 0.5])) ) # High and low values of uniform distribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5))) + CompletionInferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5))) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0))) + CompletionInferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0))) # Static values should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=1.5) + CompletionInferenceParameters(top_p=1.5) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=-0.1) + CompletionInferenceParameters(top_p=-0.1) # Valid top_p values shouldn't raise validation errors try: - InferenceParameters(top_p=0.1) - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0))) - InferenceParameters( + CompletionInferenceParameters(top_p=0.1) + CompletionInferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0))) + CompletionInferenceParameters( top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.0], weights=[0.5, 0.5])) ) except Exception: - pytest.fail("Unexpected exception raised during InferenceParameters top_p validation") + pytest.fail("Unexpected exception raised during CompletionInferenceParameters top_p validation") def test_generation_parameters_max_tokens_validation(): @@ -196,15 +205,15 @@ def test_generation_parameters_max_tokens_validation(): ValidationError, match="Input should 
diff --git a/tests/conftest.py b/tests/conftest.py
index 31dc0057..46b5d318 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,7 +17,7 @@
 from data_designer.config.config_builder import DataDesignerConfigBuilder
 from data_designer.config.data_designer_config import DataDesignerConfig
 from data_designer.config.datastore import DatastoreSettings
-from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
+from data_designer.config.models import CompletionInferenceParameters, ModelConfig, ModelProvider


 @pytest.fixture
@@ -135,7 +135,7 @@ def stub_model_configs() -> list[ModelConfig]:
         ModelConfig(
             alias="stub-model",
             model="stub-model",
-            inference_parameters=InferenceParameters(
+            inference_parameters=CompletionInferenceParameters(
                 temperature=0.9,
                 top_p=0.9,
                 max_tokens=2048,
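Reviewer note: the fixture above keeps static parameter values, but the validation tests earlier in this diff also accept distribution-valued sampling parameters. A hedged sketch of both forms; the distribution classes are assumed to be importable from the same module as the parameter class.

from data_designer.config.models import (  # import paths assumed
    CompletionInferenceParameters,
    ManualDistribution,
    ManualDistributionParams,
    UniformDistribution,
    UniformDistributionParams,
)

# Static values, as in the fixture above.
static_params = CompletionInferenceParameters(temperature=0.9, top_p=0.9, max_tokens=2048)

# Distribution-valued: temperature drawn from [0.5, 2.0], top_p from a weighted set.
sampled_params = CompletionInferenceParameters(
    temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0)),
    top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.0], weights=[0.5, 0.5])),
)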
diff --git a/tests/engine/column_generators/generators/test_llm_generators.py b/tests/engine/column_generators/generators/test_llm_completion_generators.py
similarity index 92%
rename from tests/engine/column_generators/generators/test_llm_generators.py
rename to tests/engine/column_generators/generators/test_llm_completion_generators.py
index 259f3a08..0b787b7e 100644
--- a/tests/engine/column_generators/generators/test_llm_generators.py
+++ b/tests/engine/column_generators/generators/test_llm_completion_generators.py
@@ -11,7 +11,7 @@
     LLMStructuredColumnConfig,
     LLMTextColumnConfig,
 )
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.engine.column_generators.generators.llm_completion import (
     DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS,
     DEFAULT_MAX_CONVERSATION_RESTARTS,
     REASONING_TRACE_COLUMN_POSTFIX,
@@ -94,7 +94,7 @@ def test_generate_method():
     assert call_args[1]["multi_modal_context"] is None


-@patch("data_designer.engine.column_generators.generators.llm_generators.logger", autospec=True)
+@patch("data_designer.engine.column_generators.generators.base.logger", autospec=True)
 def test_log_pre_generation(mock_logger):
     generator, mock_resource_provider, _, mock_model_config, _, _, _ = _create_generator_with_mocks()
     mock_model_config.model_dump_json.return_value = '{"test": "config"}'
@@ -259,20 +259,3 @@ def test_generate_with_json_deserialization():
     result = generator.generate(data)

     assert result["test_column"] == {"result": "json_output"}
-
-
-def test_generate_with_inference_parameters():
-    generator, _, mock_model, _, mock_inference_params, mock_prompt_renderer, mock_response_recipe = (
-        _create_generator_with_mocks()
-    )
-
-    mock_inference_params.generate_kwargs = {"temperature": 0.7, "max_tokens": 100}
-    _setup_generate_mocks(mock_prompt_renderer, mock_response_recipe, mock_model)
-
-    data = {"input": "test_input"}
-    generator.generate(data)
-
-    call_args = mock_model.generate.call_args
-    assert call_args[1]["temperature"] == 0.7
-    assert call_args[1]["max_tokens"] == 100
-    assert call_args[1]["purpose"] == "running generation for column 'test_column'"
diff --git a/tests/engine/column_generators/test_registry.py b/tests/engine/column_generators/test_registry.py
index f70b0d90..0d325937 100644
--- a/tests/engine/column_generators/test_registry.py
+++ b/tests/engine/column_generators/test_registry.py
@@ -3,7 +3,7 @@

 from data_designer.config.column_types import DataDesignerColumnType
 from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.engine.column_generators.generators.llm_completion import (
     LLMCodeCellGenerator,
     LLMJudgeCellGenerator,
     LLMStructuredCellGenerator,
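Reviewer note: both files above only chase the module move (llm_generators -> llm_completion, with the logger relocated to the generators' shared base module). Any out-of-tree tests that patch these symbols need the same treatment; an illustrative helper, not part of this diff:

from unittest.mock import patch

# Target string taken verbatim from this diff; the helper itself is hypothetical.
NEW_LOGGER_TARGET = "data_designer.engine.column_generators.generators.base.logger"

def patch_generator_logger():
    # Patching the pre-rename path would fail with ModuleNotFoundError.
    return patch(NEW_LOGGER_TARGET, autospec=True)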
diff --git a/tests/engine/models/conftest.py b/tests/engine/models/conftest.py
index 95e6941f..7edcd073 100644
--- a/tests/engine/models/conftest.py
+++ b/tests/engine/models/conftest.py
@@ -5,7 +5,7 @@

 import pytest

-from data_designer.config.models import InferenceParameters, ModelConfig
+from data_designer.config.models import CompletionInferenceParameters, ModelConfig
 from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
 from data_designer.engine.models.registry import ModelRegistry, create_model_registry
 from data_designer.engine.secret_resolver import SecretsFileResolver
@@ -38,7 +38,7 @@ def stub_model_configs() -> list[ModelConfig]:
             alias="stub-text",
             model="stub-model-text",
             provider="stub-model-provider",
-            inference_parameters=InferenceParameters(
+            inference_parameters=CompletionInferenceParameters(
                 temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100
             ),
         ),
@@ -46,7 +46,7 @@ def stub_model_configs() -> list[ModelConfig]:
             alias="stub-reasoning",
             model="stub-model-reasoning",
             provider="stub-model-provider",
-            inference_parameters=InferenceParameters(
+            inference_parameters=CompletionInferenceParameters(
                 temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100
             ),
         ),
diff --git a/tests/engine/models/test_facade.py b/tests/engine/models/test_facade.py
index 4fa73d9a..8765d0ab 100644
--- a/tests/engine/models/test_facade.py
+++ b/tests/engine/models/test_facade.py
@@ -4,7 +4,7 @@
 from collections import namedtuple
 from unittest.mock import patch

-from litellm.types.utils import Choices, Message, ModelResponse
+from litellm.types.utils import Choices, EmbeddingResponse, Message, ModelResponse
 import pytest

 from data_designer.engine.models.errors import ModelGenerationValidationFailureError
@@ -30,10 +30,20 @@ def stub_model_facade(stub_model_configs, stub_secrets_resolver, stub_model_prov


 @pytest.fixture
-def stub_expected_response():
+def stub_completion_messages():
+    return [{"role": "user", "content": "test"}]
+
+
+@pytest.fixture
+def stub_expected_completion_response():
     return ModelResponse(choices=Choices(message=Message(content="Test response")))


+@pytest.fixture
+def stub_expected_embedding_response():
+    return EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2)
+
+
 @pytest.mark.parametrize(
     "max_correction_steps,max_conversation_restarts,total_calls",
     [
@@ -105,6 +115,24 @@ def test_usage_stats_property(stub_model_facade):
     assert hasattr(stub_model_facade.usage_stats, "model_dump")


+def test_consolidate_kwargs(stub_model_configs, stub_model_facade):
+    # Model config generate kwargs are used as base, and purpose is removed
+    result = stub_model_facade.consolidate_kwargs(purpose="test")
+    assert result == stub_model_configs[0].inference_parameters.generate_kwargs
+
+    # kwargs overrides model config generate kwargs
+    result = stub_model_facade.consolidate_kwargs(temperature=0.01, purpose="test")
+    assert result == {**stub_model_configs[0].inference_parameters.generate_kwargs, "temperature": 0.01}
+
+    # Provider extra_body overrides all other kwargs
+    stub_model_facade.model_provider.extra_body = {"foo_provider": "bar_provider"}
+    result = stub_model_facade.consolidate_kwargs(extra_body={"foo": "bar"}, purpose="test")
+    assert result == {
+        **stub_model_configs[0].inference_parameters.generate_kwargs,
+        "extra_body": {"foo_provider": "bar_provider", "foo": "bar"},
+    }
+
+
 @pytest.mark.parametrize(
     "skip_usage_tracking",
     [
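Reviewer note: test_consolidate_kwargs just above fixes a three-level precedence (model config < call-site kwargs < provider extra_body, with purpose stripped). A hedged reimplementation of what the assertions imply; the shipped method may differ in details:

from typing import Any, Optional

def consolidate_kwargs_sketch(
    model_generate_kwargs: dict[str, Any],
    provider_extra_body: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> dict[str, Any]:
    kwargs.pop("purpose", None)  # purpose is logging metadata, never forwarded
    merged = {**model_generate_kwargs, **kwargs}  # call-site kwargs win over config
    if provider_extra_body:
        # provider extra_body is merged last, so its keys override call-site ones
        merged["extra_body"] = {**merged.get("extra_body", {}), **provider_extra_body}
    return merged

# Mirrors the third assertion above:
assert consolidate_kwargs_sketch(
    {"temperature": 0.8},
    provider_extra_body={"foo_provider": "bar_provider"},
    extra_body={"foo": "bar"},
    purpose="test",
) == {"temperature": 0.8, "extra_body": {"foo_provider": "bar_provider", "foo": "bar"}}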
@@ -112,63 +140,85 @@ def test_usage_stats_property(stub_model_facade):
         False,
         True,
     ],
 )
-def test_completion_success(stub_model_facade, stub_expected_response, skip_usage_tracking):
-    stub_model_facade._router.completion = lambda model_name, messages, **kwargs: stub_expected_response
-
-    messages = [{"role": "user", "content": "test"}]
-    result = stub_model_facade.completion(messages, skip_usage_tracking=skip_usage_tracking)
-
-    assert result == stub_expected_response
-
-
-def test_completion_with_exception(stub_model_facade):
-    def raise_exception(*args, **kwargs):
-        raise Exception("Router error")
+@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+def test_completion_success(
+    mock_router_completion,
+    stub_completion_messages,
+    stub_model_configs,
+    stub_model_facade,
+    stub_expected_completion_response,
+    skip_usage_tracking,
+):
+    mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response
+    result = stub_model_facade.completion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking)
+    assert result == stub_expected_completion_response
+    assert mock_router_completion.call_count == 1
+    assert mock_router_completion.call_args[1] == {
+        "model": "stub-model-text",
+        "messages": stub_completion_messages,
+        **stub_model_configs[0].inference_parameters.generate_kwargs,
+    }

-    stub_model_facade._router.completion = raise_exception
-    messages = [{"role": "user", "content": "test"}]

+@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+def test_completion_with_exception(mock_router_completion, stub_completion_messages, stub_model_facade):
+    mock_router_completion.side_effect = Exception("Router error")
     with pytest.raises(Exception, match="Router error"):
-        stub_model_facade.completion(messages)
+        stub_model_facade.completion(stub_completion_messages)


-def test_completion_with_kwargs(stub_model_facade, stub_expected_response):
+@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+def test_completion_with_kwargs(
+    mock_router_completion,
+    stub_completion_messages,
+    stub_model_configs,
+    stub_model_facade,
+    stub_expected_completion_response,
+):
     captured_kwargs = {}

-    def mock_completion(model_name, messages, **kwargs):
+    def mock_completion(self, model, messages, **kwargs):
         captured_kwargs.update(kwargs)
-        return stub_expected_response
+        return stub_expected_completion_response

-    stub_model_facade._router.completion = mock_completion
+    mock_router_completion.side_effect = mock_completion

-    messages = [{"role": "user", "content": "test"}]
     kwargs = {"temperature": 0.7, "max_tokens": 100}
-    result = stub_model_facade.completion(messages, **kwargs)
+    result = stub_model_facade.completion(stub_completion_messages, **kwargs)

-    assert result == stub_expected_response
-    assert captured_kwargs == kwargs
+    assert result == stub_expected_completion_response
+    # completion kwargs overrides model config generate kwargs
+    assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs}


-@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
-def test_completion_with_extra_body(mock_router_completion, stub_model_facade):
-    messages = [{"role": "user", "content": "test"}]
-
-    # completion call has no extra body argument and provider has no extra body
-    _ = stub_model_facade.completion(messages)
-    assert len(mock_router_completion.call_args) == 2
-    assert mock_router_completion.call_args[0][1] == "stub-model-text"
-    assert mock_router_completion.call_args[0][2] == messages
-
-    # completion call has no extra body argument and provider has extra body.
-    # Should pull extra body from model provider
-    custom_extra_body = {"some_custom_key": "some_custom_value"}
-    stub_model_facade.model_provider.extra_body = custom_extra_body
-    _ = stub_model_facade.completion(messages)
-    assert mock_router_completion.call_args[1] == {"extra_body": custom_extra_body}
-
-    # completion call has extra body argument and provider has extra body.
-    # Should merge the two with provider extra body taking precedence
-    completion_extra_body = {"some_completion_key": "some_completion_value", "some_custom_key": "some_different_value"}
-    _ = stub_model_facade.completion(messages, extra_body=completion_extra_body)
-    assert mock_router_completion.call_args[1] == {"extra_body": {**completion_extra_body, **custom_extra_body}}
+@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
+def test_generate_text_embeddings_success(mock_router_embedding, stub_model_facade, stub_expected_embedding_response):
+    mock_router_embedding.side_effect = lambda self, model, input, **kwargs: stub_expected_embedding_response
+    input_texts = ["test1", "test2"]
+    result = stub_model_facade.generate_text_embeddings(input_texts)
+    assert result == [data["embedding"] for data in stub_expected_embedding_response.data]
+
+
+@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
+def test_generate_text_embeddings_with_exception(mock_router_embedding, stub_model_facade):
+    mock_router_embedding.side_effect = Exception("Router error")
+
+    with pytest.raises(Exception, match="Router error"):
+        stub_model_facade.generate_text_embeddings(["test1", "test2"])
+
+
+@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
+def test_generate_text_embeddings_with_kwargs(
+    mock_router_embedding, stub_model_configs, stub_model_facade, stub_expected_embedding_response
+):
+    captured_kwargs = {}
+
+    def mock_embedding(self, model, input, **kwargs):
+        captured_kwargs.update(kwargs)
+        return stub_expected_embedding_response
+
+    mock_router_embedding.side_effect = mock_embedding
+    kwargs = {"temperature": 0.7, "max_tokens": 100, "input_type": "query"}
+    _ = stub_model_facade.generate_text_embeddings(["test1", "test2"], **kwargs)
+    assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs}
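Reviewer note: generate_text_embeddings is asserted to unwrap the router's EmbeddingResponse into one plain vector per input, in order. A self-contained sketch of that unwrapping, reusing the exact fixture shape from above:

from litellm.types.utils import EmbeddingResponse

# Same shape as the stub_expected_embedding_response fixture above.
response = EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2)

# The facade is tested to return one vector per input text, in order.
vectors = [item["embedding"] for item in response.data]
assert vectors == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]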
diff --git a/tests/engine/models/test_model_registry.py b/tests/engine/models/test_model_registry.py
index 571b9605..83e3b650 100644
--- a/tests/engine/models/test_model_registry.py
+++ b/tests/engine/models/test_model_registry.py
@@ -6,7 +6,7 @@
 from litellm import AuthenticationError
 import pytest

-from data_designer.config.models import InferenceParameters, ModelConfig
+from data_designer.config.models import CompletionInferenceParameters, ModelConfig
 from data_designer.engine.models.errors import ModelAuthenticationError
 from data_designer.engine.models.facade import ModelFacade
 from data_designer.engine.models.registry import ModelRegistry, create_model_registry
@@ -24,7 +24,7 @@ def stub_new_model_config():
         alias="stub-vision",
         model="stub-model-vision",
         provider="stub-model-provider",
-        inference_parameters=InferenceParameters(
+        inference_parameters=CompletionInferenceParameters(
             temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100
         ),
     )
@@ -36,7 +36,7 @@ def stub_no_usage_config():
         alias="no-usage",
         model="no-usage-model",
         provider="stub-model-provider",
-        inference_parameters=InferenceParameters(),
+        inference_parameters=CompletionInferenceParameters(),
     )
diff --git a/tests/engine/processing/test_utils.py b/tests/engine/processing/test_utils.py
index a41e0ec2..dec0fe6a 100644
--- a/tests/engine/processing/test_utils.py
+++ b/tests/engine/processing/test_utils.py
@@ -9,6 +9,7 @@
 from data_designer.engine.processing.utils import (
     concat_datasets,
     deserialize_json_values,
+    parse_list_string,
 )
@@ -116,3 +117,19 @@ def test_concat_datasets_logging(mock_logger, stub_sample_dataframes):
 def test_deserialize_json_values_scenarios(test_case, input_data, expected_result):
     result = deserialize_json_values(input_data)
     assert result == expected_result
+
+
+@pytest.mark.parametrize(
+    "input_string,expected_result",
+    [
+        ('["a", "b", "c"]', ["a", "b", "c"]),  # valid stringified json array
+        ('[" a ", " b", "c "]', ["a", "b", "c"]),  # valid stringified json array with whitespace
+        ('["a", "b", "c",]', ["a", "b", "c"]),  # valid stringified json array with trailing comma
+        ("['a', 'b', 'c']", ["a", "b", "c"]),  # valid python-style list with single quotes
+        ("['a', 'b', 'c', ]", ["a", "b", "c"]),  # valid python-style list with trailing comma
+        ("simple string ", ["simple string"]),  # simple string with whitespace
+    ],
+)
+def test_parse_list_string_scenarios(input_string, expected_result):
+    result = parse_list_string(input_string)
+    assert result == expected_result
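Reviewer note: the parametrized cases above fully determine parse_list_string's observable behavior. A hedged reimplementation that passes all six cases; the shipped version may differ:

import ast
import json

def parse_list_string_sketch(value: str) -> list[str]:
    text = value.strip()
    if text.startswith("[") and text.endswith("]"):
        inner = text[1:-1].rstrip()
        if inner.endswith(","):                 # tolerate trailing commas
            text = "[" + inner[:-1] + "]"
        try:
            items = json.loads(text)            # double-quoted JSON arrays
        except json.JSONDecodeError:
            items = ast.literal_eval(text)      # python-style single quotes
        return [str(item).strip() for item in items]
    return [text]                               # bare string -> single-item list

assert parse_list_string_sketch("['a', 'b', 'c', ]") == ["a", "b", "c"]
assert parse_list_string_sketch("simple string ") == ["simple string"]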
diff --git a/tests/essentials/test_init.py b/tests/essentials/test_init.py
index 89f8388a..d810bba3 100644
--- a/tests/essentials/test_init.py
+++ b/tests/essentials/test_init.py
@@ -17,14 +17,17 @@
     CodeLang,
     CodeValidatorParams,
     ColumnInequalityConstraint,
+    CompletionInferenceParameters,
     DataDesignerColumnType,
     DataDesignerConfig,
     DataDesignerConfigBuilder,
     DatastoreSeedDatasetReference,
     DatastoreSettings,
     DatetimeSamplerParams,
+    EmbeddingInferenceParameters,
     ExpressionColumnConfig,
     GaussianSamplerParams,
+    GenerationType,
     ImageContext,
     ImageFormat,
     InferenceParameters,
@@ -109,6 +112,9 @@ def test_model_config_imports():
     assert ImageContext is not None
     assert ImageFormat is not None
     assert InferenceParameters is not None
+    assert CompletionInferenceParameters is not None
+    assert EmbeddingInferenceParameters is not None
+    assert GenerationType is not None
     assert ManualDistribution is not None
     assert ManualDistributionParams is not None
     assert Modality is not None
@@ -232,6 +238,7 @@ def test_all_contains_column_configs():
     assert "Score" in __all__
     assert "SeedDatasetColumnConfig" in __all__
     assert "ValidationColumnConfig" in __all__
+    assert "EmbeddingColumnConfig" in __all__


 def test_all_contains_sampler_params():
@@ -250,6 +257,8 @@ def test_all_contains_sampler_params():
     assert "TimeDeltaSamplerParams" in __all__
     assert "UniformSamplerParams" in __all__
     assert "UUIDSamplerParams" in __all__
+    assert "PersonFromFakerSamplerParams" in __all__
+    assert "ProcessorType" in __all__


 def test_all_contains_constraints():
@@ -263,6 +272,9 @@ def test_all_contains_model_configs():
     assert "ImageContext" in __all__
     assert "ImageFormat" in __all__
     assert "InferenceParameters" in __all__
+    assert "CompletionInferenceParameters" in __all__
+    assert "EmbeddingInferenceParameters" in __all__
+    assert "GenerationType" in __all__
     assert "ManualDistribution" in __all__
     assert "ManualDistributionParams" in __all__
     assert "Modality" in __all__
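Reviewer note: for a quick local smoke test of the expanded public surface, something like the following works, assuming the package under test is importable as data_designer.essentials (the test location suggests so):

# Hedged sketch; module path is an assumption based on tests/essentials/test_init.py.
from data_designer import essentials

for name in (
    "CompletionInferenceParameters",
    "EmbeddingInferenceParameters",
    "GenerationType",
    "EmbeddingColumnConfig",
):
    assert name in essentials.__all__
    assert getattr(essentials, name) is not None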