Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ ModelConfig(
top_p=0.9,
max_tokens=4096,
),
)
)
```

The value `openai/gpt-oss-20b` would be collected.
Expand Down
3 changes: 0 additions & 3 deletions docs/plugins/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
name="index-multiplier",
description="Generates values by multiplying the row index by a user-specified multiplier",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=None,
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -110,7 +109,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
- `metadata()` describes your generator and its requirements
- `generation_strategy` can be `FULL_COLUMN`, `CELL_BY_CELL`
- You have access to the configuration parameters via `self.config`
- `required_resources` lists any required resources (models, artifact storages, etc.). This parameter will evolve in the near future, so keeping it as `None` is safe for now. That said, if your task will use the model registry, adding `data_designer.engine.resources.ResourceType.MODEL_REGISTRY` will enable automatic model health checking for your column generation task.

!!! info "Understanding generation_strategy"
The `generation_strategy` specifies how the column generator will generate data.
Expand Down Expand Up @@ -179,7 +177,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
name="index-multiplier",
description="Generates values by multiplying the row index by a user-specified multiplier",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=None,
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
35 changes: 0 additions & 35 deletions src/data_designer/config/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,41 +62,6 @@
)


def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool:
"""Return True if the column type is used in the workflow execution DAG."""
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
dag_column_types = {
DataDesignerColumnType.EXPRESSION,
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_TEXT,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EMBEDDING,
}
dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
return column_type in dag_column_types


def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool:
"""Return True if the column type is a model-generated column."""
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
model_generated_column_types = {
DataDesignerColumnType.LLM_TEXT,
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
DataDesignerColumnType.EMBEDDING,
}
model_generated_column_types.update(
plugin_manager.get_plugin_column_types(
DataDesignerColumnType,
required_resources=["model_registry"],
)
)
return column_type in model_generated_column_types


def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
"""Create a Data Designer column config object from kwargs.

Expand Down
18 changes: 0 additions & 18 deletions src/data_designer/config/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from data_designer.config.column_types import (
ColumnConfigT,
DataDesignerColumnType,
column_type_is_model_generated,
get_column_config_from_kwargs,
get_column_display_order,
)
Expand Down Expand Up @@ -422,23 +421,6 @@ def get_constraints(self, target_column: str) -> list[ColumnConstraintT]:
"""
return [c for c in self._constraints if c.target_column == target_column]

def get_llm_gen_columns(self) -> list[ColumnConfigT]:
"""Get all model-generated column configurations.

Returns:
A list of column configurations that use model generation.
"""
logger.warning("get_llm_gen_columns is deprecated. Use get_model_gen_columns instead.")
return self.get_model_gen_columns()

def get_model_gen_columns(self) -> list[ColumnConfigT]:
"""Get all model-generated column configurations.

Returns:
A list of column configurations that use model generation.
"""
return [c for c in self._column_configs.values() if column_type_is_model_generated(c.column_type)]

def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]:
"""Get all column configurations of the specified type.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
)
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.recipes.response_recipes import TextResponseRecipe
from data_designer.engine.resources.resource_provider import ResourceType

logger = logging.getLogger(__name__)

Expand All @@ -42,7 +41,6 @@ def metadata() -> ColumnProfilerMetadata:
return ColumnProfilerMetadata(
name="judge_score_profiler",
description="Analyzes LLM-as-judge score distributions in a Data Designer dataset.",
required_resources=[ResourceType.MODEL_REGISTRY],
applicable_column_types=[DataDesignerColumnType.LLM_JUDGE],
)

Expand Down
39 changes: 26 additions & 13 deletions src/data_designer/engine/column_generators/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if TYPE_CHECKING:
from data_designer.config.models import BaseInferenceParams, ModelConfig
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.registry import ModelRegistry


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,27 +73,39 @@ def can_generate_from_scratch(self) -> bool:
def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...


class WithModelGeneration:
class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC):
@property
def model_registry(self) -> ModelRegistry:
return self.resource_provider.model_registry

def get_model(self, model_alias: str) -> ModelFacade:
return self.model_registry.get_model(model_alias=model_alias)

def get_model_config(self, model_alias: str) -> ModelConfig:
return self.model_registry.get_model_config(model_alias=model_alias)

def get_model_provider_name(self, model_alias: str) -> str:
provider = self.model_registry.get_model_provider(model_alias=model_alias)
return provider.name


class ColumnGeneratorWithModel(ColumnGeneratorWithModelRegistry[TaskConfigT], ABC):
@functools.cached_property
def model(self) -> ModelFacade:
return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
return self.get_model(model_alias=self.config.model_alias)

@functools.cached_property
def model_config(self) -> ModelConfig:
return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
return self.get_model_config(model_alias=self.config.model_alias)

@functools.cached_property
def inference_parameters(self) -> BaseInferenceParams:
return self.model_config.inference_parameters

def log_pre_generation(self) -> None:
logger.info(f"Preparing {self.config.column_type} column generation")
logger.info(f" |-- column name: {self.config.name!r}")
logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
if self.model_config.provider is None:
logger.info(f" |-- default model provider: {self._get_provider_name()!r}")

def _get_provider_name(self) -> str:
model_alias = self.model_config.alias
provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
return provider.name
logger.info(f"{self.config.column_type} model configuration for generating column '{self.config.name}'")
logger.info(f" |-- model: {self.model_config.model!r}")
logger.info(f" |-- model alias: {self.config.model_alias!r}")
logger.info(f" |-- model provider: {self.get_model_provider_name(model_alias=self.config.model_alias)!r}")
logger.info(f" |-- generation type: {self.model_config.generation_type.value!r}")
logger.info(f" |-- inference parameters: {self.inference_parameters.format_for_display()}")
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@

from data_designer.config.column_configs import EmbeddingColumnConfig
from data_designer.engine.column_generators.generators.base import (
ColumnGenerator,
ColumnGeneratorWithModel,
GenerationStrategy,
GeneratorMetadata,
WithModelGeneration,
)
from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string
from data_designer.engine.resources.resource_provider import ResourceType


class EmbeddingGenerationResult(BaseModel):
Expand All @@ -27,14 +25,13 @@ def dimension(self) -> int:
return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0


class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]):
class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="embedding_cell_generator",
description="Generate embeddings for a text column.",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)

def generate(self, data: dict) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging

import pandas as pd
Expand All @@ -25,7 +27,6 @@ def metadata() -> GeneratorMetadata:
name="expression_generator",
description="Generate a column from a jinja2 expression.",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=None,
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@
)
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
from data_designer.engine.column_generators.generators.base import (
ColumnGenerator,
ColumnGeneratorWithModel,
GenerationStrategy,
GeneratorMetadata,
WithModelGeneration,
)
from data_designer.engine.column_generators.utils.prompt_renderer import (
PromptType,
RecordBasedPromptRenderer,
create_response_recipe,
)
from data_designer.engine.configurable_task import TaskConfigT
from data_designer.engine.models.recipes.base import ResponseRecipe
from data_designer.engine.processing.utils import deserialize_json_values
from data_designer.engine.resources.resource_provider import ResourceType

logger = logging.getLogger(__name__)

Expand All @@ -33,7 +32,7 @@
DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0


class WithChatCompletionGeneration(WithModelGeneration):
class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfigT]):
@functools.cached_property
def response_recipe(self) -> ResponseRecipe:
return create_response_recipe(self.config, self.model_config)
Expand Down Expand Up @@ -92,47 +91,43 @@ def generate(self, data: dict) -> dict:
return data


class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]):
class LLMTextCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMTextColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="llm_text_generator",
description="Generate a new dataset cell from a prompt template",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)


class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]):
class LLMCodeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMCodeColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="llm_code_generator",
description="Generate a new dataset cell from a prompt template",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)


class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
class LLMStructuredCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMStructuredColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="llm_structured_generator",
description="Generate a new dataset cell from a prompt template",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)


class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
class LLMJudgeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMJudgeColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="llm_judge_generator",
description="Judge a new dataset cell based on a set of rubrics",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging
import random
from functools import partial
Expand All @@ -17,7 +19,6 @@
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
from data_designer.engine.processing.utils import concat_datasets
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
from data_designer.engine.resources.resource_provider import ResourceType
from data_designer.engine.sampling_gen.data_sources.sources import SamplerType
from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator
Expand All @@ -32,7 +33,6 @@ def metadata() -> GeneratorMetadata:
name="sampler_column_generator",
description="Generate columns using sampling-based method.",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=[ResourceType.BLOB_STORAGE],
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from __future__ import annotations

import functools
import logging

Expand All @@ -16,7 +19,6 @@
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
from data_designer.engine.processing.utils import concat_datasets
from data_designer.engine.resources.resource_provider import ResourceType

MAX_ZERO_RECORD_RESPONSE_FACTOR = 2

Expand All @@ -30,7 +32,6 @@ def metadata() -> GeneratorMetadata:
name="seed_dataset_column_generator",
description="Sample columns from a seed dataset.",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=[ResourceType.SEED_READER],
)

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging

import pandas as pd
Expand Down Expand Up @@ -50,7 +52,6 @@ def metadata() -> GeneratorMetadata:
name="validate",
description="Validate data.",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=None,
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
Loading