61 changes: 61 additions & 0 deletions src/data_designer/config/column_configs.py
@@ -377,3 +377,64 @@ class SeedDatasetColumnConfig(SingleColumnConfig):
"""

column_type: Literal["seed-dataset"] = "seed-dataset"


class EmbeddingColumnConfig(SingleColumnConfig):
    """Configuration for embedding generation columns.

    Embedding columns generate embeddings for text input using a specified model.

    Attributes:
        column_type: Discriminator field, always "embedding" for this configuration type.
        target_column: The column to generate embeddings for. The column may contain
            either a single text string or a JSON-encoded list of text strings. If it
            is a JSON-encoded list, an embedding is generated for each string.
        model_alias: The model to use for embedding generation.
    """

    column_type: Literal["embedding"] = "embedding"
    target_column: str
    model_alias: str

    @property
    def required_columns(self) -> list[str]:
        return [self.target_column]


class ImageGenerationColumnConfig(SingleColumnConfig):
    """Configuration for image generation columns.

    Image columns generate images using a specified model.

    Attributes:
        column_type: Discriminator field, always "image-generation" for this configuration type.
        prompt: Prompt template for image generation. Supports Jinja2 templating to
            reference other columns (e.g., "Generate an image of a {{ character_name }}").
            Must be a valid Jinja2 template.
        model_alias: The model to use for image generation.
    """

    column_type: Literal["image-generation"] = "image-generation"
    prompt: str
    model_alias: str

    @property
    def required_columns(self) -> list[str]:
        """Get columns referenced in the prompt template.

        Returns:
            List of unique column names referenced in Jinja2 templates.
        """
        return list(get_prompt_template_keywords(self.prompt))

    @model_validator(mode="after")
    def assert_prompt_valid_jinja(self) -> Self:
        """Validate that prompt is a valid Jinja2 template.

        Returns:
            The validated instance.

        Raises:
            InvalidConfigError: If prompt contains invalid Jinja2 syntax.
        """
        assert_valid_jinja2_template(self.prompt)
        return self
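
A minimal usage sketch for the two new column configs (the column names and model aliases are illustrative placeholders, not values from this PR):

    embedding_col = EmbeddingColumnConfig(
        name="review_embedding",
        target_column="review_text",  # plain string, or a JSON-encoded list of strings
        model_alias="embedding-model",
    )
    assert embedding_col.required_columns == ["review_text"]

    image_col = ImageGenerationColumnConfig(
        name="character_portrait",
        prompt="A watercolor portrait of {{ character_name }}",
        model_alias="image-model",
    )
    assert image_col.required_columns == ["character_name"]
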
16 changes: 16 additions & 0 deletions src/data_designer/config/column_types.py
@@ -7,7 +7,9 @@

from ..plugin_manager import PluginManager
from .column_configs import (
    EmbeddingColumnConfig,
    ExpressionColumnConfig,
    ImageGenerationColumnConfig,
    LLMCodeColumnConfig,
    LLMJudgeColumnConfig,
    LLMStructuredColumnConfig,
@@ -31,6 +33,8 @@
    SamplerColumnConfig,
    SeedDatasetColumnConfig,
    ValidationColumnConfig,
    EmbeddingColumnConfig,
    ImageGenerationColumnConfig,
]
ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT)

@@ -50,6 +54,8 @@
    DataDesignerColumnType.SEED_DATASET: "🌱",
    DataDesignerColumnType.SAMPLER: "🎲",
    DataDesignerColumnType.VALIDATION: "🔍",
    DataDesignerColumnType.EMBEDDING: "🧬",
    DataDesignerColumnType.IMAGE_GENERATION: "🖼️",
}
COLUMN_TYPE_EMOJI_MAP.update(
    {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()}
@@ -66,6 +72,8 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn
        DataDesignerColumnType.LLM_STRUCTURED,
        DataDesignerColumnType.LLM_TEXT,
        DataDesignerColumnType.VALIDATION,
        DataDesignerColumnType.EMBEDDING,
        DataDesignerColumnType.IMAGE_GENERATION,
    }
    dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
    return column_type in dag_column_types
@@ -79,6 +87,8 @@ def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]
        DataDesignerColumnType.LLM_CODE,
        DataDesignerColumnType.LLM_STRUCTURED,
        DataDesignerColumnType.LLM_JUDGE,
        DataDesignerColumnType.EMBEDDING,
        DataDesignerColumnType.IMAGE_GENERATION,
    }
    llm_generated_column_types.update(
        plugin_manager.get_plugin_column_types(
@@ -117,6 +127,10 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
        return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs))
    if column_type == DataDesignerColumnType.SEED_DATASET:
        return SeedDatasetColumnConfig(name=name, **kwargs)
    if column_type == DataDesignerColumnType.EMBEDDING:
        return EmbeddingColumnConfig(name=name, **kwargs)
    if column_type == DataDesignerColumnType.IMAGE_GENERATION:
        return ImageGenerationColumnConfig(name=name, **kwargs)
    if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value):
        return plugin.config_cls(name=name, **kwargs)
    raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.")  # pragma: no cover
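
# Illustrative sketch of the two new dispatch branches above; the name, alias,
# and target column are placeholders:
#
#     config = get_column_config_from_kwargs(
#         name="doc_embedding",
#         column_type=DataDesignerColumnType.EMBEDDING,
#         target_column="doc_text",
#         model_alias="embedding-model",
#     )
#     assert isinstance(config, EmbeddingColumnConfig)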
@@ -131,6 +145,8 @@ def get_column_display_order() -> list[DataDesignerColumnType]:
        DataDesignerColumnType.LLM_CODE,
        DataDesignerColumnType.LLM_STRUCTURED,
        DataDesignerColumnType.LLM_JUDGE,
        DataDesignerColumnType.EMBEDDING,
        DataDesignerColumnType.IMAGE_GENERATION,
        DataDesignerColumnType.VALIDATION,
        DataDesignerColumnType.EXPRESSION,
    ]
23 changes: 13 additions & 10 deletions src/data_designer/config/default_model_settings.py
@@ -8,7 +8,7 @@
from pathlib import Path
from typing import Any, Literal, Optional

from .models import InferenceParameters, ModelConfig, ModelProvider
from .models import CompletionInferenceParameters, ModelConfig, ModelProvider
from .utils.constants import (
    MANAGED_ASSETS_PATH,
    MODEL_CONFIGS_FILE_PATH,
@@ -21,28 +21,30 @@
logger = logging.getLogger(__name__)


def get_default_text_alias_inference_parameters() -> InferenceParameters:
    return InferenceParameters(
def get_default_text_alias_inference_parameters() -> CompletionInferenceParameters:
    return CompletionInferenceParameters(
        temperature=0.85,
        top_p=0.95,
    )


def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
    return InferenceParameters(
def get_default_reasoning_alias_inference_parameters() -> CompletionInferenceParameters:
    return CompletionInferenceParameters(
        temperature=0.35,
        top_p=0.95,
    )


def get_default_vision_alias_inference_parameters() -> InferenceParameters:
    return InferenceParameters(
def get_default_vision_alias_inference_parameters() -> CompletionInferenceParameters:
    return CompletionInferenceParameters(
        temperature=0.85,
        top_p=0.95,
    )


def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters:
def get_default_inference_parameters(
    model_alias: Literal["text", "reasoning", "vision"],
) -> CompletionInferenceParameters:
    if model_alias == "reasoning":
        return get_default_reasoning_alias_inference_parameters()
    elif model_alias == "vision":
@@ -103,15 +105,16 @@ def resolve_seed_default_model_settings() -> None:
f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
)
save_config_file(
MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]}
MODEL_CONFIGS_FILE_PATH,
{"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]},
)

    if not MODEL_PROVIDERS_FILE_PATH.exists():
        logger.debug(
            f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
        )
        save_config_file(
            MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
            MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]}
        )

    if not MANAGED_ASSETS_PATH.exists():
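
A note on the model_dump(mode="json") change above: ModelConfig now carries a GenerationType enum, and dumping in Python mode keeps raw enum members that config-file writers may not serialize cleanly. JSON mode converts them to plain strings first. An illustrative sketch (the alias and model name are placeholders):

    mc = ModelConfig(alias="text", model="example/model")
    mc.model_dump()["generation_type"]             # GenerationType.CHAT_COMPLETION
    mc.model_dump(mode="json")["generation_type"]  # "chat-completion"
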
107 changes: 96 additions & 11 deletions src/data_designer/config/models.py
@@ -5,7 +5,7 @@
from enum import Enum
import logging
from pathlib import Path
from typing import Any, Generic, List, Optional, TypeVar, Union
from typing import Any, Generic, List, Literal, Optional, TypeVar, Union

import numpy as np
from pydantic import BaseModel, Field, model_validator
@@ -136,17 +136,29 @@ def sample(self) -> float:
DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]


class InferenceParameters(ConfigBase):
    temperature: Optional[Union[float, DistributionT]] = None
    top_p: Optional[Union[float, DistributionT]] = None
    max_tokens: Optional[int] = Field(default=None, ge=1)
class BaseInferenceParameters(ConfigBase, ABC):
    max_parallel_requests: int = Field(default=4, ge=1)
    timeout: Optional[int] = Field(default=None, ge=1)
    extra_body: Optional[dict[str, Any]] = None

    @property
    def generate_kwargs(self) -> dict[str, Union[float, int]]:
    def generate_kwargs(self) -> dict[str, Any]:
        result = {}
        if self.timeout is not None:
            result["timeout"] = self.timeout
        if self.extra_body is not None and self.extra_body != {}:
            result["extra_body"] = self.extra_body
        return result


class CompletionInferenceParameters(BaseInferenceParameters):
    temperature: Optional[Union[float, DistributionT]] = None
    top_p: Optional[Union[float, DistributionT]] = None
    max_tokens: Optional[int] = Field(default=None, ge=1)

    @property
    def generate_kwargs(self) -> dict[str, Any]:
        result = super().generate_kwargs
        if self.temperature is not None:
            result["temperature"] = (
                self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
@@ -155,10 +167,6 @@ def generate_kwargs(self) -> dict[str, Union[float, int]]:
result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
if self.max_tokens is not None:
result["max_tokens"] = self.max_tokens
if self.timeout is not None:
result["timeout"] = self.timeout
if self.extra_body is not None and self.extra_body != {}:
result["extra_body"] = self.extra_body
return result

    @model_validator(mode="after")
@@ -205,12 +213,89 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) -
        return min_value <= value <= max_value


# Maintain backwards compatibility with a deprecation warning
class InferenceParameters(CompletionInferenceParameters):
    """
    Deprecated: Use CompletionInferenceParameters instead.
    This alias will be removed in a future version.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        logger.warning(
            "InferenceParameters is deprecated and will be removed in a future version. "
            "Use CompletionInferenceParameters instead."
        )
        super().__init__(*args, **kwargs)
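
# Illustrative behavior of the shim above: the old name still constructs, but
# it logs a deprecation warning, and instances remain valid anywhere a
# CompletionInferenceParameters is expected.
#
#     params = InferenceParameters(temperature=0.7)  # logs the warning
#     assert isinstance(params, CompletionInferenceParameters)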


class EmbeddingInferenceParameters(BaseInferenceParameters):
    encoding_format: Optional[Literal["float", "base64"]] = None
    dimensions: Optional[int] = None

    @property
    def generate_kwargs(self) -> dict[str, Any]:
        result = super().generate_kwargs
        if self.encoding_format is not None:
            result["encoding_format"] = self.encoding_format
        if self.dimensions is not None:
            result["dimensions"] = self.dimensions
        return result


class ImageGenerationInferenceParameters(BaseInferenceParameters):
    quality: str
    size: str
    output_format: Optional[ModalityDataType] = ModalityDataType.BASE64

    @property
    def generate_kwargs(self) -> dict[str, Any]:
        result = super().generate_kwargs
        result["size"] = self.size
        result["quality"] = self.quality
        result["response_format"] = "b64_json" if self.output_format == ModalityDataType.BASE64 else self.output_format
        return result


InferenceParametersT: TypeAlias = Union[
    InferenceParameters, CompletionInferenceParameters, EmbeddingInferenceParameters, ImageGenerationInferenceParameters
]


class GenerationType(str, Enum):
    CHAT_COMPLETION = "chat-completion"
    EMBEDDING = "embedding"
    IMAGE_GENERATION = "image-generation"


class ModelConfig(ConfigBase):
    alias: str
    model: str
    inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
    inference_parameters: InferenceParametersT = Field(default_factory=CompletionInferenceParameters)
    generation_type: Optional[GenerationType] = Field(default=GenerationType.CHAT_COMPLETION)
    provider: Optional[str] = None

    @model_validator(mode="after")
    def _normalize_deprecated_inference_parameters(self) -> Self:
        """Normalize deprecated InferenceParameters to CompletionInferenceParameters."""
        if isinstance(self.inference_parameters, InferenceParameters):
            self.inference_parameters = CompletionInferenceParameters(**self.inference_parameters.model_dump())
        return self

    @model_validator(mode="after")
    def _validate_generation_type(self) -> Self:
        generation_type_instance_map = {
            GenerationType.CHAT_COMPLETION: CompletionInferenceParameters,
            GenerationType.EMBEDDING: EmbeddingInferenceParameters,
            GenerationType.IMAGE_GENERATION: ImageGenerationInferenceParameters,
        }
        if self.generation_type not in generation_type_instance_map:
            raise ValueError(f"Invalid generation type: {self.generation_type}")
        if not isinstance(self.inference_parameters, generation_type_instance_map[self.generation_type]):
            raise ValueError(
                f"Inference parameters must be an instance of {generation_type_instance_map[self.generation_type].__name__!r} when generation_type is {self.generation_type!r}"
            )
        return self


class ModelProvider(ConfigBase):
    name: str
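
An illustrative pairing of generation_type with its inference-parameter class (the alias and model name are placeholders); mismatched pairs are rejected by the _validate_generation_type validator above:

    embed_model = ModelConfig(
        alias="embeddings",
        model="example/embedding-model",
        generation_type=GenerationType.EMBEDDING,
        inference_parameters=EmbeddingInferenceParameters(dimensions=512),
    )

    # Raises a validation error: completion parameters cannot be paired
    # with generation_type "embedding".
    # ModelConfig(
    #     alias="bad",
    #     model="example/model",
    #     generation_type=GenerationType.EMBEDDING,
    #     inference_parameters=CompletionInferenceParameters(),
    # )
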
15 changes: 14 additions & 1 deletion src/data_designer/config/utils/visualization.py
@@ -8,7 +8,7 @@
from functools import cached_property
import json
import os
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import pandas as pd
@@ -171,13 +171,18 @@ def display_sample_record(
        + config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
        + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
        + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
        + config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING)
    )
    if len(non_code_columns) > 0:
        table = Table(title="Generated Columns", **table_kws)
        table.add_column("Name")
        table.add_column("Value")
        for col in non_code_columns:
            if not col.drop:
                if col.column_type == DataDesignerColumnType.EMBEDDING:
                    record[col.name]["embeddings"] = [
                        get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings", [])
                    ]
                table.add_row(col.name, convert_to_row_element(record[col.name]))
        render_list.append(pad_console_element(table))

@@ -237,6 +242,14 @@ def display_sample_record(
    console.print(Group(*render_list), markup=False)


def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str:
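    """Render a list as a short string, keeping only the first `max_items` items."""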
    if len(long_list) > max_items:
        truncated_part = long_list[:max_items]
        return f"[{', '.join(str(x) for x in truncated_part)} ...]"
    else:
        return str(long_list)
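
# Illustrative behavior of the helper above:
#
#     get_truncated_list_as_string([0.12, 0.34, 0.56, 0.78])  # -> "[0.12, 0.34 ...]"
#     get_truncated_list_as_string([0.12])                    # -> "[0.12]"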


def display_sampler_table(
    sampler_params: dict[SamplerType, ConfigBase],
    title: Optional[str] = None,
@@ -23,7 +23,7 @@
    SingleColumnConfig,
    ValidationColumnConfig,
)
from data_designer.engine.column_generators.generators.llm_generators import (
from data_designer.engine.column_generators.utils.prompt_renderer import (
    PromptType,
    RecordBasedPromptRenderer,
    create_response_recipe,