24 changes: 24 additions & 0 deletions src/data_designer/config/column_configs.py
@@ -377,3 +377,27 @@ class SeedDatasetColumnConfig(SingleColumnConfig):
"""

column_type: Literal["seed-dataset"] = "seed-dataset"


class EmbeddingColumnConfig(SingleColumnConfig):
"""Configuration for embedding generation columns.

Embedding columns generate embeddings for text input using a specified model.

Attributes:
column_type: Discriminator field, always "embedding" for this configuration type.
target_column: The column to generate embeddings for.
model_alias: The model to use for embedding generation.
chunk_pattern: Optional regex pattern used to split the text in the target column into chunks. For example,
if chunk_pattern is r'\n+', the text is split on one or more newlines and an embedding is generated for each chunk.
If not provided, the entire text is embedded as a single chunk.
"""

column_type: Literal["embedding"] = "embedding"
target_column: str
model_alias: str
chunk_pattern: Optional[str] = None

@property
def required_columns(self) -> list[str]:
return [self.target_column]
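
As a usage sketch (the column and alias names below are hypothetical, and the `name` field is assumed to be inherited from SingleColumnConfig), a config that chunks a document column on blank lines before embedding could look like:

from data_designer.config.column_configs import EmbeddingColumnConfig

config = EmbeddingColumnConfig(
    name="document_embeddings",  # hypothetical column name
    target_column="document",    # hypothetical source column
    model_alias="embedding",     # hypothetical model alias
    chunk_pattern=r"\n\n+",      # split on blank lines; omit to embed the text as one chunk
)
assert config.required_columns == ["document"]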
8 changes: 8 additions & 0 deletions src/data_designer/config/column_types.py
@@ -7,6 +7,7 @@

from ..plugin_manager import PluginManager
from .column_configs import (
EmbeddingColumnConfig,
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
@@ -31,6 +32,7 @@
SamplerColumnConfig,
SeedDatasetColumnConfig,
ValidationColumnConfig,
EmbeddingColumnConfig,
]
ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT)

@@ -50,6 +52,7 @@
DataDesignerColumnType.SEED_DATASET: "🌱",
DataDesignerColumnType.SAMPLER: "🎲",
DataDesignerColumnType.VALIDATION: "🔍",
DataDesignerColumnType.EMBEDDING: "🧬",
}
COLUMN_TYPE_EMOJI_MAP.update(
{DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()}
@@ -66,6 +69,7 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_TEXT,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EMBEDDING,
}
dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
return column_type in dag_column_types
@@ -79,6 +83,7 @@ def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.EMBEDDING,
}
llm_generated_column_types.update(
plugin_manager.get_plugin_column_types(
@@ -117,6 +122,8 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs))
if column_type == DataDesignerColumnType.SEED_DATASET:
return SeedDatasetColumnConfig(name=name, **kwargs)
if column_type == DataDesignerColumnType.EMBEDDING:
return EmbeddingColumnConfig(name=name, **kwargs)
if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value):
return plugin.config_cls(name=name, **kwargs)
raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover
@@ -131,6 +138,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]:
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.EMBEDDING,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EXPRESSION,
]
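
A quick sketch of the new dispatch path (assuming both names are importable from this module; the column name and alias are illustrative):

from data_designer.config.column_types import DataDesignerColumnType, get_column_config_from_kwargs

cfg = get_column_config_from_kwargs(
    name="doc_vecs",
    column_type=DataDesignerColumnType.EMBEDDING,
    target_column="document",
    model_alias="embedding",
)
# cfg is an EmbeddingColumnConfig built from the forwarded kwargs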
23 changes: 13 additions & 10 deletions src/data_designer/config/default_model_settings.py
@@ -8,7 +8,7 @@
from pathlib import Path
from typing import Any, Literal, Optional

from .models import InferenceParameters, ModelConfig, ModelProvider
from .models import CompletionInferenceParameters, ModelConfig, ModelProvider
from .utils.constants import (
MANAGED_ASSETS_PATH,
MODEL_CONFIGS_FILE_PATH,
@@ -21,28 +21,30 @@
logger = logging.getLogger(__name__)


def get_default_text_alias_inference_parameters() -> InferenceParameters:
return InferenceParameters(
def get_default_text_alias_inference_parameters() -> CompletionInferenceParameters:
return CompletionInferenceParameters(
temperature=0.85,
top_p=0.95,
)


def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
return InferenceParameters(
def get_default_reasoning_alias_inference_parameters() -> CompletionInferenceParameters:
return CompletionInferenceParameters(
temperature=0.35,
top_p=0.95,
)


def get_default_vision_alias_inference_parameters() -> InferenceParameters:
return InferenceParameters(
def get_default_vision_alias_inference_parameters() -> CompletionInferenceParameters:
return CompletionInferenceParameters(
temperature=0.85,
top_p=0.95,
)


def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters:
def get_default_inference_parameters(
model_alias: Literal["text", "reasoning", "vision"],
) -> CompletionInferenceParameters:
if model_alias == "reasoning":
return get_default_reasoning_alias_inference_parameters()
elif model_alias == "vision":
@@ -103,15 +105,16 @@ def resolve_seed_default_model_settings() -> None:
f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
)
save_config_file(
MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]}
MODEL_CONFIGS_FILE_PATH,
{"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]},
)

if not MODEL_PROVIDERS_FILE_PATH.exists():
logger.debug(
f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
)
save_config_file(
MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]}
)

if not MANAGED_ASSETS_PATH.exists():
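
The switch to model_dump(mode="json") matters because it coerces nested values (enums, sub-models, distributions) to JSON-safe primitives before they are written to the config files; a minimal sketch (the model identifier is illustrative):

from data_designer.config.models import CompletionInferenceParameters, ModelConfig

mc = ModelConfig(
    alias="text",
    model="example/chat-model",  # illustrative model identifier
    inference_parameters=CompletionInferenceParameters(temperature=0.85, top_p=0.95),
)
mc.model_dump(mode="json")  # plain dicts, strings, and numbers, safe for JSON/YAML files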
84 changes: 74 additions & 10 deletions src/data_designer/config/models.py
@@ -5,7 +5,7 @@
from enum import Enum
import logging
from pathlib import Path
from typing import Any, Generic, List, Optional, TypeVar, Union
from typing import Any, Generic, List, Literal, Optional, TypeVar, Union

import numpy as np
from pydantic import BaseModel, Field, model_validator
@@ -136,17 +136,29 @@ def sample(self) -> float:
DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]


class InferenceParameters(ConfigBase):
temperature: Optional[Union[float, DistributionT]] = None
top_p: Optional[Union[float, DistributionT]] = None
max_tokens: Optional[int] = Field(default=None, ge=1)
class BaseInferenceParameters(ConfigBase, ABC):
max_parallel_requests: int = Field(default=4, ge=1)
timeout: Optional[int] = Field(default=None, ge=1)
extra_body: Optional[dict[str, Any]] = None

@property
def generate_kwargs(self) -> dict[str, Any]:
result = {}
if self.timeout is not None:
result["timeout"] = self.timeout
if self.extra_body is not None and self.extra_body != {}:
result["extra_body"] = self.extra_body
return result


class CompletionInferenceParameters(BaseInferenceParameters):
temperature: Optional[Union[float, DistributionT]] = None
top_p: Optional[Union[float, DistributionT]] = None
max_tokens: Optional[int] = Field(default=None, ge=1)

@property
def generate_kwargs(self) -> dict[str, Any]:
result = super().generate_kwargs
if self.temperature is not None:
result["temperature"] = (
self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
@@ -155,10 +167,6 @@ def generate_kwargs(self) -> dict[str, Union[float, int]]:
result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
if self.max_tokens is not None:
result["max_tokens"] = self.max_tokens
if self.timeout is not None:
result["timeout"] = self.timeout
if self.extra_body is not None and self.extra_body != {}:
result["extra_body"] = self.extra_body
return result

@model_validator(mode="after")
@@ -205,12 +213,68 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) -
return min_value <= value <= max_value


# Maintain backwards compatibility with a deprecation warning
class InferenceParameters(CompletionInferenceParameters):
"""
Deprecated: Use CompletionInferenceParameters instead.
This alias will be removed in a future version.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.warning(
"InferenceParameters is deprecated and will be removed in a future version. "
"Use CompletionInferenceParameters instead."
)
super().__init__(*args, **kwargs)
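
Behaviorally, this shim keeps existing configs working: instantiating the old class emits the warning but still produces a CompletionInferenceParameters instance. An illustrative check, not part of the diff:

params = InferenceParameters(temperature=0.5)  # logs the deprecation warning
assert isinstance(params, CompletionInferenceParameters)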


class EmbeddingInferenceParameters(BaseInferenceParameters):
encoding_format: Optional[Literal["float", "base64"]] = None
dimensions: Optional[int] = None

@property
def generate_kwargs(self) -> dict[str, Any]:
result = super().generate_kwargs
if self.encoding_format is not None:
result["encoding_format"] = self.encoding_format
if self.dimensions is not None:
result["dimensions"] = self.dimensions
return result


InferenceParametersT: TypeAlias = Union[
InferenceParameters, CompletionInferenceParameters, EmbeddingInferenceParameters
]


class GenerationType(str, Enum):
CHAT_COMPLETION = "chat-completion"
EMBEDDING = "embedding"
IMAGE_GENERATION = "image-generation"


class ModelConfig(ConfigBase):
alias: str
model: str
inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
inference_parameters: InferenceParametersT = Field(default_factory=CompletionInferenceParameters)
provider: Optional[str] = None

@model_validator(mode="after")
def _normalize_deprecated_inference_parameters(self) -> Self:
"""Normalize deprecated InferenceParameters to CompletionInferenceParameters."""
if isinstance(self.inference_parameters, InferenceParameters):
self.inference_parameters = CompletionInferenceParameters(**self.inference_parameters.model_dump())
return self

@property
def generation_type(self) -> GenerationType:
if isinstance(self.inference_parameters, CompletionInferenceParameters):
return GenerationType.CHAT_COMPLETION
elif isinstance(self.inference_parameters, EmbeddingInferenceParameters):
return GenerationType.EMBEDDING
else:
raise ValueError(f"Unsupported inference parameters type: {type(self.inference_parameters)}")
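
With generation_type, a model's role is declared entirely by the type of its inference parameters. A hedged sketch using the names defined above in models.py (alias and model identifier are illustrative):

embed_model = ModelConfig(
    alias="embedding",                 # illustrative alias
    model="example/embedding-model",   # illustrative model identifier
    inference_parameters=EmbeddingInferenceParameters(dimensions=256),
)
assert embed_model.generation_type is GenerationType.EMBEDDING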


class ModelProvider(ConfigBase):
name: str
15 changes: 14 additions & 1 deletion src/data_designer/config/utils/visualization.py
@@ -8,7 +8,7 @@
from functools import cached_property
import json
import os
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import pandas as pd
@@ -171,13 +171,18 @@ def display_sample_record(
+ config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
+ config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING)
)
if len(non_code_columns) > 0:
table = Table(title="Generated Columns", **table_kws)
table.add_column("Name")
table.add_column("Value")
for col in non_code_columns:
if not col.drop:
if col.column_type == DataDesignerColumnType.EMBEDDING:
record[col.name]["embeddings"] = [
get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings")
]
table.add_row(col.name, convert_to_row_element(record[col.name]))
render_list.append(pad_console_element(table))

@@ -237,6 +242,14 @@ def display_sample_record(
console.print(Group(*render_list), markup=False)


def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str:
if len(long_list) > max_items:
truncated_part = long_list[:max_items]
return f"[{', '.join(str(x) for x in truncated_part)} ...]"
else:
return str(long_list)


def display_sampler_table(
sampler_params: dict[SamplerType, ConfigBase],
title: Optional[str] = None,
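
For reference, the truncation helper above renders long embedding vectors compactly:

get_truncated_list_as_string([0.12, -0.34, 0.56, 0.78])  # -> "[0.12, -0.34 ...]"
get_truncated_list_as_string([0.12, -0.34])              # -> "[0.12, -0.34]"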
Changes to a further file (path not captured in this diff view):
@@ -23,7 +23,7 @@
SingleColumnConfig,
ValidationColumnConfig,
)
from data_designer.engine.column_generators.generators.llm_generators import (
from data_designer.engine.column_generators.utils.prompt_renderer import (
PromptType,
RecordBasedPromptRenderer,
create_response_recipe,
48 changes: 48 additions & 0 deletions src/data_designer/engine/column_generators/generators/base.py
@@ -2,12 +2,22 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
import functools
import logging
from typing import overload

import pandas as pd

from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
from data_designer.config.models import BaseInferenceParameters, ModelConfig
from data_designer.config.utils.type_helpers import StrEnum
from data_designer.engine.column_generators.utils.prompt_renderer import (
RecordBasedPromptRenderer,
)
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
from data_designer.engine.models.facade import ModelFacade

logger = logging.getLogger(__name__)


class GenerationStrategy(StrEnum):
@@ -59,3 +69,41 @@ def can_generate_from_scratch(self) -> bool:

@abstractmethod
def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...


class WithModelGeneration:
@functools.cached_property
def model(self) -> ModelFacade:
return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)

@functools.cached_property
def model_config(self) -> ModelConfig:
return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)

@functools.cached_property
def inference_parameters(self) -> BaseInferenceParameters:
return self.model_config.inference_parameters

@functools.cached_property
def prompt_renderer(self) -> RecordBasedPromptRenderer:
return RecordBasedPromptRenderer(
response_recipe=self.response_recipe,
error_message_context={
"column_name": self.config.name,
"column_type": self.config.column_type,
"model_alias": self.config.model_alias,
},
)

def log_pre_generation(self) -> None:
emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
logger.info(f" |-- column name: {self.config.name!r}")
logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
if self.model_config.provider is None:
logger.info(f" |-- default model provider: {self._get_provider_name()!r}")

def _get_provider_name(self) -> str:
model_alias = self.model_config.alias
provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
return provider.name
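
A hedged sketch of adopting the mixin (the host class below is hypothetical; WithModelGeneration assumes the host exposes a `config` with name, column_type, and model_alias, a `resource_provider` with a model registry, and a `response_recipe` when prompt_renderer is used):

class ExampleEmbeddingGenerator(WithModelGeneration):  # hypothetical host class
    def __init__(self, config, resource_provider) -> None:
        self.config = config                        # must provide name, column_type, model_alias
        self.resource_provider = resource_provider  # must provide model_registry

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        self.log_pre_generation()
        # self.model, self.model_config, and self.inference_parameters are the
        # mixin's cached properties; generation logic would go here.
        ...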