Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ ModelConfig(
top_p=0.9,
max_tokens=4096,
),
)
)
```

The value `openai/gpt-oss-20b` would be collected.
Expand Down
3 changes: 0 additions & 3 deletions docs/plugins/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
name="index-multiplier",
description="Generates values by multiplying the row index by a user-specified multiplier",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=None,
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -110,7 +109,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
- `metadata()` describes your generator and its requirements
- `generation_strategy` can be either `FULL_COLUMN` or `CELL_BY_CELL`
- You have access to the configuration parameters via `self.config`
- `required_resources` lists any required resources (models, artifact storages, etc.). This parameter will evolve in the near future, so keeping it as `None` is safe for now. That said, if your task will use the model registry, adding `data_designer.engine.resources.ResourceType.MODEL_REGISTRY` will enable automatic model health checking for your column generation task.

!!! info "Understanding generation_strategy"
The `generation_strategy` specifies how the column generator will generate data.
Expand Down Expand Up @@ -179,7 +177,6 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
name="index-multiplier",
description="Generates values by multiplying the row index by a user-specified multiplier",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=None,
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
35 changes: 0 additions & 35 deletions src/data_designer/config/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,41 +62,6 @@
)


def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool:
    """Report whether a column type takes part in the workflow execution DAG.

    Args:
        column_type: A ``DataDesignerColumnType`` member, or a string that
            ``resolve_string_enum`` resolves to one.

    Returns:
        True if the resolved type is one of the built-in DAG column types
        listed below, or one of the plugin column types reported by the
        plugin manager; False otherwise.
    """
    resolved_type = resolve_string_enum(column_type, DataDesignerColumnType)
    builtin_dag_types = {
        DataDesignerColumnType.EXPRESSION,
        DataDesignerColumnType.LLM_CODE,
        DataDesignerColumnType.LLM_JUDGE,
        DataDesignerColumnType.LLM_STRUCTURED,
        DataDesignerColumnType.LLM_TEXT,
        DataDesignerColumnType.VALIDATION,
        DataDesignerColumnType.EMBEDDING,
    }
    # Plugin-provided column types are always considered part of the DAG.
    plugin_types = plugin_manager.get_plugin_column_types(DataDesignerColumnType)
    return resolved_type in builtin_dag_types.union(plugin_types)


def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool:
    """Report whether a column type is produced by a model.

    Args:
        column_type: A ``DataDesignerColumnType`` member, or a string that
            ``resolve_string_enum`` resolves to one.

    Returns:
        True if the resolved type is one of the built-in model-generated
        column types listed below, or a plugin column type that the plugin
        manager returns when filtered by the ``"model_registry"`` required
        resource; False otherwise.
    """
    resolved_type = resolve_string_enum(column_type, DataDesignerColumnType)
    builtin_model_types = {
        DataDesignerColumnType.LLM_TEXT,
        DataDesignerColumnType.LLM_CODE,
        DataDesignerColumnType.LLM_STRUCTURED,
        DataDesignerColumnType.LLM_JUDGE,
        DataDesignerColumnType.EMBEDDING,
    }
    # Only plugins that declare the model registry as a required resource
    # are requested from the plugin manager here.
    plugin_model_types = plugin_manager.get_plugin_column_types(
        DataDesignerColumnType,
        required_resources=["model_registry"],
    )
    return resolved_type in builtin_model_types.union(plugin_model_types)


def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
"""Create a Data Designer column config object from kwargs.

Expand Down
18 changes: 0 additions & 18 deletions src/data_designer/config/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from data_designer.config.column_types import (
ColumnConfigT,
DataDesignerColumnType,
column_type_is_model_generated,
get_column_config_from_kwargs,
get_column_display_order,
)
Expand Down Expand Up @@ -422,23 +421,6 @@ def get_constraints(self, target_column: str) -> list[ColumnConstraintT]:
"""
return [c for c in self._constraints if c.target_column == target_column]

def get_llm_gen_columns(self) -> list[ColumnConfigT]:
    """Deprecated alias for ``get_model_gen_columns``.

    Logs a deprecation warning through the module logger, then delegates
    to ``get_model_gen_columns``.

    Returns:
        The list returned by ``get_model_gen_columns``.
    """
    logger.warning("get_llm_gen_columns is deprecated. Use get_model_gen_columns instead.")
    model_gen_columns = self.get_model_gen_columns()
    return model_gen_columns

def get_model_gen_columns(self) -> list[ColumnConfigT]:
    """Collect the column configurations whose type is model-generated.

    Returns:
        The stored column configurations for which
        ``column_type_is_model_generated`` returns True, in the iteration
        order of ``self._column_configs``.
    """
    model_gen_columns = []
    for column_config in self._column_configs.values():
        if column_type_is_model_generated(column_config.column_type):
            model_gen_columns.append(column_config)
    return model_gen_columns

def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]:
"""Get all column configurations of the specified type.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
)
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.recipes.response_recipes import TextResponseRecipe
from data_designer.engine.resources.resource_provider import ResourceType

logger = logging.getLogger(__name__)

Expand All @@ -42,7 +41,6 @@ def metadata() -> ColumnProfilerMetadata:
return ColumnProfilerMetadata(
name="judge_score_profiler",
description="Analyzes LLM-as-judge score distributions in a Data Designer dataset.",
required_resources=[ResourceType.MODEL_REGISTRY],
applicable_column_types=[DataDesignerColumnType.LLM_JUDGE],
)

Expand Down
39 changes: 26 additions & 13 deletions src/data_designer/engine/column_generators/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if TYPE_CHECKING:
from data_designer.config.models import BaseInferenceParams, ModelConfig
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.registry import ModelRegistry


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,27 +73,39 @@ def can_generate_from_scratch(self) -> bool:
def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...


class WithModelGeneration:
class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC):
    """Column generator base with convenience lookups on the model registry.

    Every helper forwards to the ``model_registry`` exposed by this
    generator's ``resource_provider``.
    """

    @property
    def model_registry(self) -> ModelRegistry:
        """The model registry held by this generator's resource provider."""
        provider = self.resource_provider
        return provider.model_registry

    def get_model(self, model_alias: str) -> ModelFacade:
        """Look up the model registered under ``model_alias``."""
        registry = self.model_registry
        return registry.get_model(model_alias=model_alias)

    def get_model_config(self, model_alias: str) -> ModelConfig:
        """Look up the model config registered under ``model_alias``."""
        registry = self.model_registry
        return registry.get_model_config(model_alias=model_alias)

    def get_model_provider_name(self, model_alias: str) -> str:
        """Return the ``name`` of the provider the registry gives for ``model_alias``."""
        return self.model_registry.get_model_provider(model_alias=model_alias).name


class ColumnGeneratorWithModel(ColumnGeneratorWithModelRegistry[TaskConfigT], ABC):
@functools.cached_property
def model(self) -> ModelFacade:
return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
return self.get_model(model_alias=self.config.model_alias)

@functools.cached_property
def model_config(self) -> ModelConfig:
return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
return self.get_model_config(model_alias=self.config.model_alias)

@functools.cached_property
def inference_parameters(self) -> BaseInferenceParams:
return self.model_config.inference_parameters

def log_pre_generation(self) -> None:
logger.info(f"Preparing {self.config.column_type} column generation")
logger.info(f" |-- column name: {self.config.name!r}")
logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
if self.model_config.provider is None:
logger.info(f" |-- default model provider: {self._get_provider_name()!r}")

def _get_provider_name(self) -> str:
model_alias = self.model_config.alias
provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
return provider.name
logger.info(f"{self.config.column_type} model configuration for generating column '{self.config.name}'")
logger.info(f" |-- model: {self.model_config.model!r}")
logger.info(f" |-- model alias: {self.config.model_alias!r}")
logger.info(f" |-- model provider: {self.get_model_provider_name(model_alias=self.config.model_alias)!r}")
logger.info(f" |-- generation type: {self.model_config.generation_type.value!r}")
logger.info(f" |-- inference parameters: {self.inference_parameters.format_for_display()}")
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@

from data_designer.config.column_configs import EmbeddingColumnConfig
from data_designer.engine.column_generators.generators.base import (
ColumnGenerator,
ColumnGeneratorWithModel,
GenerationStrategy,
GeneratorMetadata,
WithModelGeneration,
)
from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string
from data_designer.engine.resources.resource_provider import ResourceType


class EmbeddingGenerationResult(BaseModel):
Expand All @@ -27,14 +25,13 @@ def dimension(self) -> int:
return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0


class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]):
class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="embedding_cell_generator",
description="Generate embeddings for a text column.",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)

def generate(self, data: dict) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging

import pandas as pd
Expand All @@ -25,7 +27,6 @@ def metadata() -> GeneratorMetadata:
name="expression_generator",
description="Generate a column from a jinja2 expression.",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=None,
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@
)
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
from data_designer.engine.column_generators.generators.base import (
ColumnGenerator,
ColumnGeneratorWithModel,
GenerationStrategy,
GeneratorMetadata,
WithModelGeneration,
)
from data_designer.engine.column_generators.utils.prompt_renderer import (
PromptType,
RecordBasedPromptRenderer,
create_response_recipe,
)
from data_designer.engine.configurable_task import TaskConfigT
from data_designer.engine.models.recipes.base import ResponseRecipe
from data_designer.engine.processing.utils import deserialize_json_values
from data_designer.engine.resources.resource_provider import ResourceType

logger = logging.getLogger(__name__)

Expand All @@ -33,7 +32,7 @@
DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0


class WithChatCompletionGeneration(WithModelGeneration):
class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfigT]):
@functools.cached_property
def response_recipe(self) -> ResponseRecipe:
return create_response_recipe(self.config, self.model_config)
Expand Down Expand Up @@ -92,47 +91,43 @@ def generate(self, data: dict) -> dict:
return data


class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]):
class LLMTextCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMTextColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="llm_text_generator",
description="Generate a new dataset cell from a prompt template",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)


class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]):
class LLMCodeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMCodeColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="llm_code_generator",
description="Generate a new dataset cell from a prompt template",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)


class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
class LLMStructuredCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMStructuredColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="llm_structured_generator",
description="Generate a new dataset cell from a prompt template",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)


class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
class LLMJudgeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMJudgeColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
name="llm_judge_generator",
description="Judge a new dataset cell based on a set of rubrics",
generation_strategy=GenerationStrategy.CELL_BY_CELL,
required_resources=[ResourceType.MODEL_REGISTRY],
)

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging
import random
from functools import partial
Expand All @@ -17,7 +19,6 @@
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
from data_designer.engine.processing.utils import concat_datasets
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
from data_designer.engine.resources.resource_provider import ResourceType
from data_designer.engine.sampling_gen.data_sources.sources import SamplerType
from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator
Expand All @@ -32,7 +33,6 @@ def metadata() -> GeneratorMetadata:
name="sampler_column_generator",
description="Generate columns using sampling-based method.",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=[ResourceType.BLOB_STORAGE],
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from __future__ import annotations

import functools
import logging

Expand All @@ -16,7 +19,6 @@
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
from data_designer.engine.processing.utils import concat_datasets
from data_designer.engine.resources.resource_provider import ResourceType

MAX_ZERO_RECORD_RESPONSE_FACTOR = 2

Expand All @@ -30,7 +32,6 @@ def metadata() -> GeneratorMetadata:
name="seed_dataset_column_generator",
description="Sample columns from a seed dataset.",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=[ResourceType.SEED_READER],
)

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging

import pandas as pd
Expand Down Expand Up @@ -50,7 +52,6 @@ def metadata() -> GeneratorMetadata:
name="validate",
description="Validate data.",
generation_strategy=GenerationStrategy.FULL_COLUMN,
required_resources=None,
)

def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
Loading