diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 78cfe724..ce0ddf8a 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -24,7 +24,6 @@ ) from .data_designer_config import DataDesignerConfig from .dataset_builders import BuildStage -from .datastore import DatastoreSettings, fetch_seed_dataset_column_names from .errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError from .models import ModelConfig, load_model_configs from .processors import ProcessorConfig, ProcessorType, get_processor_config_from_kwargs @@ -35,19 +34,16 @@ ScalarInequalityConstraint, ) from .seed import ( - DatastoreSeedDatasetReference, IndexRange, - LocalSeedDatasetReference, PartitionBlock, SamplingStrategy, SeedConfig, - SeedDatasetReference, + SeedDatasetReferenceT, ) from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE from .utils.info import DataDesignerInfo from .utils.io_helpers import serialize_data, smart_load_yaml from .utils.misc import ( - can_run_data_designer_locally, json_indent_list_of_strings, kebab_to_snake, ) @@ -61,17 +57,16 @@ class BuilderConfig(ExportableConfigBase): """Configuration container for Data Designer builder. This class holds the main Data Designer configuration along with optional - datastore settings needed for seed dataset operations. + seed dataset reference settings. Attributes: data_designer: The main Data Designer configuration containing columns, constraints, profilers, and other settings. - datastore_settings: Optional datastore settings for accessing external - datasets. + seed_dataset_reference: Information about the seed dataset, if one is in use. """ data_designer: DataDesignerConfig - datastore_settings: Optional[DatastoreSettings] + seed_dataset_reference: Optional[SeedDatasetReferenceT] class DataDesignerConfigBuilder: @@ -104,30 +99,22 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self: builder_config = BuilderConfig.model_validate(json_config) builder = cls(model_configs=builder_config.data_designer.model_configs) - config = builder_config.data_designer + dd_config = builder_config.data_designer - for col in config.columns: + for col in dd_config.columns: builder.add_column(col) - for constraint in config.constraints or []: + for constraint in dd_config.constraints or []: builder.add_constraint(constraint=constraint) - if config.seed_config: - if builder_config.datastore_settings is None: - if can_run_data_designer_locally(): - seed_dataset_reference = LocalSeedDatasetReference(dataset=config.seed_config.dataset) - else: - raise BuilderConfigurationError("🛑 Datastore settings are required.") - else: - seed_dataset_reference = DatastoreSeedDatasetReference( - dataset=config.seed_config.dataset, - datastore_settings=builder_config.datastore_settings, - ) - builder.set_seed_datastore_settings(builder_config.datastore_settings) + if dd_config.seed_config: + if (seed_dataset_reference := builder_config.seed_dataset_reference) is None: + # TODO: Should this just log a warning and recommend re-running with_seed_dataset, or raise? 
+ raise BuilderConfigurationError("🛑 Found seed_config without seed_dataset_reference.") builder.with_seed_dataset( seed_dataset_reference, - sampling_strategy=config.seed_config.sampling_strategy, - selection_strategy=config.seed_config.selection_strategy, + sampling_strategy=dd_config.seed_config.sampling_strategy, + selection_strategy=dd_config.seed_config.selection_strategy, ) return builder @@ -148,7 +135,7 @@ def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] self._constraints: list[ColumnConstraintT] = [] self._profilers: list[ColumnProfilerConfigT] = [] self._info = DataDesignerInfo() - self._datastore_settings: Optional[DatastoreSettings] = None + self._seed_dataset_reference: Optional[SeedDatasetReferenceT] = None @property def model_configs(self) -> list[ModelConfig]: @@ -497,13 +484,13 @@ def get_seed_config(self) -> Optional[SeedConfig]: """ return self._seed_config - def get_seed_datastore_settings(self) -> Optional[DatastoreSettings]: - """Get most recent datastore settings for the current Data Designer configuration. + def get_seed_dataset_reference(self) -> Optional[SeedDatasetReferenceT]: + """Get the seed dataset reference for the current Data Designer configuration. Returns: - The datastore settings if configured, None otherwise. + The seed dataset reference if configured, None otherwise. """ - return None if not self._datastore_settings else DatastoreSettings.model_validate(self._datastore_settings) + return self._seed_dataset_reference def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int: """Get the count of columns of the specified type. @@ -516,13 +503,13 @@ def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int: """ return len(self.get_columns_of_type(column_type)) - def set_seed_datastore_settings(self, datastore_settings: Optional[DatastoreSettings]) -> Self: - """Set the datastore settings for the seed dataset. + def set_seed_dataset_reference(self, seed_dataset_reference: Optional[SeedDatasetReferenceT]) -> Self: + """Set the dataset reference. Args: - datastore_settings: The datastore settings to use for the seed dataset. + seed_dataset_reference: The seed dataset reference. """ - self._datastore_settings = datastore_settings + self._seed_dataset_reference = seed_dataset_reference return self def validate(self, *, raise_exceptions: bool = False) -> Self: @@ -554,7 +541,7 @@ def validate(self, *, raise_exceptions: bool = False) -> Self: def with_seed_dataset( self, - dataset_reference: SeedDatasetReference, + dataset_reference: SeedDatasetReferenceT, *, sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None, @@ -563,25 +550,25 @@ def with_seed_dataset( This method sets the seed dataset for the configuration and automatically creates SeedDatasetColumnConfig objects for each column found in the dataset. The column - names are fetched from the dataset source (Hugging Face Hub or NeMo Microservices Datastore). + names are fetched from the dataset source client-side using local access credentials. Args: - dataset_reference: Seed dataset reference for fetching from the datastore. + dataset_reference: An object that contains a pointer to the dataset with any necessary + credentials for reading it. sampling_strategy: The sampling strategy to use when generating data from the seed dataset. Defaults to ORDERED sampling. Returns: The current Data Designer config builder instance. 
""" + self.set_seed_dataset_reference(dataset_reference) self._seed_config = SeedConfig( - dataset=dataset_reference.dataset, + dataset=dataset_reference.get_dataset(), sampling_strategy=sampling_strategy, selection_strategy=selection_strategy, + source=dataset_reference.get_source(), ) - self.set_seed_datastore_settings( - dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None - ) - for column_name in fetch_seed_dataset_column_names(dataset_reference): + for column_name in dataset_reference.get_column_names(): self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name) return self @@ -611,7 +598,7 @@ def get_builder_config(self) -> BuilderConfig: Returns: The builder config. """ - return BuilderConfig(data_designer=self.build(), datastore_settings=self._datastore_settings) + return BuilderConfig(data_designer=self.build(), seed_dataset_reference=self.get_seed_dataset_reference()) def __repr__(self) -> str: """Generates a string representation of the DataDesignerConfigBuilder instance. diff --git a/src/data_designer/config/datastore.py b/src/data_designer/config/datastore.py deleted file mode 100644 index 9f625a13..00000000 --- a/src/data_designer/config/datastore.py +++ /dev/null @@ -1,151 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import logging -from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union - -from huggingface_hub import HfApi, HfFileSystem -import pandas as pd -import pyarrow.parquet as pq -from pydantic import BaseModel, Field - -from .errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError -from .utils.io_helpers import VALID_DATASET_FILE_EXTENSIONS, validate_path_contains_files_of_type - -if TYPE_CHECKING: - from .seed import SeedDatasetReference - -logger = logging.getLogger(__name__) - - -class DatastoreSettings(BaseModel): - """Configuration for interacting with a datastore.""" - - endpoint: str = Field( - ..., - description="Datastore endpoint. Use 'https://huggingface.co' for the Hugging Face Hub.", - ) - token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.") - - -def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]: - """Extract column names based on file type. 
Supports glob patterns like '../path/*.parquet'.""" - file_path = Path(file_path) - if "*" in str(file_path): - matching_files = sorted(file_path.parent.glob(file_path.name)) - if not matching_files: - raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}") - logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset") - file_path = matching_files[0] - - if file_type == "parquet": - try: - schema = pq.read_schema(file_path) - if hasattr(schema, "names"): - return schema.names - else: - return [field.name for field in schema] - except Exception as e: - logger.warning(f"Failed to process parquet file {file_path}: {e}") - return [] - elif file_type in ["json", "jsonl"]: - return pd.read_json(file_path, orient="records", lines=True, nrows=1).columns.tolist() - elif file_type == "csv": - try: - df = pd.read_csv(file_path, nrows=1) - return df.columns.tolist() - except (pd.errors.EmptyDataError, pd.errors.ParserError) as e: - logger.warning(f"Failed to process CSV file {file_path}: {e}") - return [] - else: - raise InvalidFilePathError(f"🛑 Unsupported file type: {file_type!r}") - - -def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference) -> list[str]: - if hasattr(seed_dataset_reference, "datastore_settings"): - return _fetch_seed_dataset_column_names_from_datastore( - seed_dataset_reference.repo_id, - seed_dataset_reference.filename, - seed_dataset_reference.datastore_settings, - ) - return _fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset) - - -def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | None) -> DatastoreSettings: - if datastore_settings is None: - raise InvalidConfigError("🛑 Datastore settings are required in order to upload datasets to the datastore.") - if isinstance(datastore_settings, DatastoreSettings): - return datastore_settings - elif isinstance(datastore_settings, dict): - return DatastoreSettings.model_validate(datastore_settings) - else: - raise InvalidConfigError( - "🛑 Invalid datastore settings format. Must be DatastoreSettings object or dictionary." 
- ) - - -def upload_to_hf_hub( - dataset_path: Union[str, Path], - filename: str, - repo_id: str, - datastore_settings: DatastoreSettings, - **kwargs, -) -> str: - datastore_settings = resolve_datastore_settings(datastore_settings) - dataset_path = _validate_dataset_path(dataset_path) - filename_ext = filename.split(".")[-1].lower() - if dataset_path.suffix.lower()[1:] != filename_ext: - raise InvalidFileFormatError( - f"🛑 Dataset file extension {dataset_path.suffix!r} does not match `filename` extension .{filename_ext!r}" - ) - - hfapi = HfApi(endpoint=datastore_settings.endpoint, token=datastore_settings.token) - hfapi.create_repo(repo_id, exist_ok=True, repo_type="dataset") - hfapi.upload_file( - path_or_fileobj=dataset_path, - path_in_repo=filename, - repo_id=repo_id, - repo_type="dataset", - **kwargs, - ) - return f"{repo_id}/{filename}" - - -def _fetch_seed_dataset_column_names_from_datastore( - repo_id: str, - filename: str, - datastore_settings: Optional[Union[DatastoreSettings, dict]] = None, -) -> list[str]: - file_type = filename.split(".")[-1] - if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS: - raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}") - - datastore_settings = resolve_datastore_settings(datastore_settings) - fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token) - - with fs.open(f"datasets/{repo_id}/{filename}") as f: - return get_file_column_names(f, file_type) - - -def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]: - dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True) - return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1]) - - -def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path: - if allow_glob_pattern and "*" in str(dataset_path): - parts = str(dataset_path).split("*.") - file_path = parts[0] - file_extension = parts[-1] - validate_path_contains_files_of_type(file_path, file_extension) - return Path(dataset_path) - if not Path(dataset_path).is_file(): - raise InvalidFilePathError("🛑 To upload a dataset to the datastore, you must provide a valid file path.") - if not Path(dataset_path).name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)): - raise InvalidFileFormatError( - "🛑 Dataset files must be in `parquet`, `csv`, or `json` (orient='records', lines=True) format." - ) - return Path(dataset_path) diff --git a/src/data_designer/config/seed.py b/src/data_designer/config/seed.py index 1fba60ad..c5798612 100644 --- a/src/data_designer/config/seed.py +++ b/src/data_designer/config/seed.py @@ -1,21 +1,29 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -from abc import ABC +from abc import ABC, abstractmethod from enum import Enum -from typing import Optional, Union - +import logging +import os +from pathlib import Path +from typing import Annotated, Literal, Optional, Union + +from huggingface_hub import HfFileSystem +import pandas as pd +import pyarrow.parquet as pq from pydantic import Field, field_validator, model_validator -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from .base import ConfigBase -from .datastore import DatastoreSettings +from .errors import InvalidFileFormatError, InvalidFilePathError from .utils.io_helpers import ( VALID_DATASET_FILE_EXTENSIONS, validate_dataset_file_path, validate_path_contains_files_of_type, ) +logger = logging.getLogger(__name__) + class SamplingStrategy(str, Enum): ORDERED = "ordered" @@ -73,6 +81,8 @@ class SeedConfig(ConfigBase): - PartitionBlock: Select a partition by splitting the dataset into N equal parts. Partition indices are zero-based (index=0 is the first partition, index=1 is the second, etc.). + source: Optional source name if you are running in a context with pre-registered, named + sources from which seed datasets can be used. Examples: Read rows sequentially from start to end: @@ -113,33 +123,116 @@ class SeedConfig(ConfigBase): dataset: str sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None + source: Optional[str] = None class SeedDatasetReference(ABC, ConfigBase): - dataset: str + @abstractmethod + def get_dataset(self) -> str: ... + @abstractmethod + def get_source(self) -> Optional[str]: ... -class DatastoreSeedDatasetReference(SeedDatasetReference): - datastore_settings: DatastoreSettings + @abstractmethod + def get_column_names(self) -> list[str]: ... 
- @property - def repo_id(self) -> str: - return "/".join(self.dataset.split("/")[:-1]) - @property - def filename(self) -> str: - return self.dataset.split("/")[-1] +class LocalSeedDatasetReference(SeedDatasetReference): + reference_type: Literal["local"] = "local" + dataset: Union[str, Path] -class LocalSeedDatasetReference(SeedDatasetReference): @field_validator("dataset", mode="after") - def validate_dataset_is_file(cls, v: str) -> str: + def validate_dataset_is_file(cls, v: Union[str, Path]) -> Union[str, Path]: valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS} - if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions): - parts = v.split("*.") + if any(str(v).endswith(wildcard) for wildcard in valid_wild_card_versions): + parts = str(v).split("*.") file_path = parts[0] file_extension = parts[-1] validate_path_contains_files_of_type(file_path, file_extension) else: validate_dataset_file_path(v) return v + + def get_dataset(self) -> str: + return str(self.dataset) + + def get_source(self) -> Optional[str]: + return None + + def get_column_names(self) -> list[str]: + file_type = Path(self.dataset).suffix.lower()[1:] + return _get_file_column_names(self.dataset, file_type) + + +class HfHubSeedDatasetReference(SeedDatasetReference): + reference_type: Literal["hf_hub"] = "hf_hub" + + dataset: str + endpoint: str = "https://huggingface.co" + token: Optional[str] = None + source_name: Optional[str] = None + + def get_dataset(self) -> str: + return self.dataset + + def get_source(self) -> Optional[str]: + return self.source_name + + def get_column_names(self) -> list[str]: + filename = self.dataset.split("/")[-1] + repo_id = "/".join(self.dataset.split("/")[:-1]) + + file_type = filename.split(".")[-1] + if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS: + raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}") + + _token = self.token + if self.token is not None: + # Check if the value is an env var name and if so resolve it, + # otherwise assume the value is the raw token string in plain text + _token = os.environ.get(self.token, self.token) + + fs = HfFileSystem(endpoint=self.endpoint, token=_token) + + with fs.open(f"datasets/{repo_id}/{filename}") as f: + return _get_file_column_names(f, file_type) + + +SeedDatasetReferenceT: TypeAlias = Annotated[ + Union[LocalSeedDatasetReference, HfHubSeedDatasetReference], + Field(discriminator="reference_type"), +] + + +def _get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]: + """Extract column names based on file type.""" + file_path = Path(file_path) + if "*" in str(file_path): + matching_files = sorted(file_path.parent.glob(file_path.name)) + if not matching_files: + raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}") + logger.debug(f"0️⃣Using the first matching file in {str(file_path)!r} to determine column names in seed dataset") + file_path = matching_files[0] + + if file_type == "parquet": + try: + schema = pq.read_schema(file_path) + if hasattr(schema, "names"): + return schema.names + else: + return [field.name for field in schema] + except Exception as e: + logger.warning(f"Failed to process parquet file {file_path}: {e}") + return [] + elif file_type in ["json", "jsonl"]: + return pd.read_json(file_path, orient="records", lines=True, nrows=1).columns.tolist() + elif file_type == "csv": + try: + df = pd.read_csv(file_path, nrows=1) + return df.columns.tolist() + except (pd.errors.EmptyDataError, 
pd.errors.ParserError) as e: + logger.warning(f"Failed to process CSV file {file_path}: {e}") + return [] + else: + raise InvalidFilePathError(f"🛑 Unsupported file type: {file_type!r}") diff --git a/src/data_designer/engine/analysis/dataset_profiler.py b/src/data_designer/engine/analysis/dataset_profiler.py index 3dd39ed0..8d10ed3e 100644 --- a/src/data_designer/engine/analysis/dataset_profiler.py +++ b/src/data_designer/engine/analysis/dataset_profiler.py @@ -107,8 +107,6 @@ def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> Col def _validate_column_profiler_configs(self) -> None: if self.config.column_profiler_configs: - if self.resource_provider.model_registry is None: - raise DatasetProfilerConfigurationError("Model registry is required for column profiler configs") self._validate_model_configs() def _validate_model_configs(self) -> None: diff --git a/src/data_designer/engine/column_generators/generators/seed_dataset.py b/src/data_designer/engine/column_generators/generators/seed_dataset.py index 35578602..e6986d1e 100644 --- a/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ b/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -30,7 +30,7 @@ def metadata() -> GeneratorMetadata: name="seed_dataset_column_generator", description="Sample columns from a seed dataset.", generation_strategy=GenerationStrategy.FULL_COLUMN, - required_resources=[ResourceType.DATASTORE], + required_resources=[ResourceType.SEED_DATASET_REPOSITORY], ) @property @@ -39,7 +39,7 @@ def num_records_sampled(self) -> int: @functools.cached_property def duckdb_conn(self) -> duckdb.DuckDBPyConnection: - return self.resource_provider.datastore.create_duckdb_connection() + return self.resource_provider.seed_dataset_repository.create_duckdb_connection(self.config.source) def generate(self, dataset: pd.DataFrame) -> pd.DataFrame: return concat_datasets([self.generate_from_scratch(len(dataset)), dataset]) @@ -57,7 +57,9 @@ def _initialize(self) -> None: self._num_records_sampled = 0 self._batch_reader = None self._df_remaining = None - self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset) + self._dataset_uri = self.resource_provider.seed_dataset_repository.get_dataset_uri( + self.config.dataset, self.config.source + ) self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0] self._index_range = self._resolve_index_range() @@ -135,7 +137,7 @@ def _sample_records(self, num_records: int) -> pd.DataFrame: num_zero_record_responses += 1 if num_zero_record_responses > MAX_ZERO_RECORD_RESPONSE_FACTOR * num_records: raise RuntimeError( - "🛑 Something went wrong while reading from the datastore. " + "🛑 Something went wrong while reading from the seed dataset source. " "Please check your connection and try again. " "If the issue persists, please contact support." ) diff --git a/src/data_designer/engine/errors.py b/src/data_designer/engine/errors.py index 83252691..00b7e1b7 100644 --- a/src/data_designer/engine/errors.py +++ b/src/data_designer/engine/errors.py @@ -15,6 +15,9 @@ class UnknownModelAliasError(DataDesignerError): ... class UnknownProviderError(DataDesignerError): ... +class UnknownSeedDatasetSourceError(DataDesignerError): ... + + class NoModelProvidersError(DataDesignerError): ... 
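For reference, a minimal sketch (not part of the patch) of how the refactored config-side API fits together: seed dataset references now carry their own access details and fetch their own column names, and the reference travels with the builder config instead of separate datastore settings. The file path, repo id, and HF_TOKEN variable below are placeholders, and with_seed_dataset reads the source to discover column names, so the referenced data must actually be reachable.

from data_designer.config.config_builder import DataDesignerConfigBuilder
from data_designer.config.seed import (
    HfHubSeedDatasetReference,
    LocalSeedDatasetReference,
    SamplingStrategy,
)

builder = DataDesignerConfigBuilder()

# A local file reference: the path is validated and column names are read
# directly from the file, so "data/seed.parquet" must exist.
local_ref = LocalSeedDatasetReference(dataset="data/seed.parquet")

# A Hugging Face Hub (or HF-compatible endpoint) reference: "token" may be a
# raw token or the name of an environment variable that holds it.
hub_ref = HfHubSeedDatasetReference(
    dataset="my-namespace/my-repo/seed.parquet",
    endpoint="https://huggingface.co",
    token="HF_TOKEN",
)

# Either reference type can be passed to with_seed_dataset.
builder.with_seed_dataset(hub_ref, sampling_strategy=SamplingStrategy.SHUFFLE)

# The reference is carried alongside the config, so a round trip restores it.
builder_config = builder.get_builder_config()
restored = DataDesignerConfigBuilder.from_config(builder_config)
assert restored.get_seed_dataset_reference() == hub_ref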
diff --git a/src/data_designer/engine/resources/resource_provider.py b/src/data_designer/engine/resources/resource_provider.py index 7cec142d..79cf0913 100644 --- a/src/data_designer/engine/resources/resource_provider.py +++ b/src/data_designer/engine/resources/resource_provider.py @@ -8,22 +8,22 @@ from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.registry import ModelRegistry, create_model_registry -from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage -from data_designer.engine.resources.seed_dataset_data_store import SeedDatasetDataStore +from data_designer.engine.resources.managed_storage import ManagedBlobStorage +from data_designer.engine.resources.seed_dataset_source import SeedDatasetRepository, SeedDatasetSourceRegistry from data_designer.engine.secret_resolver import SecretResolver class ResourceType(StrEnum): BLOB_STORAGE = "blob_storage" - DATASTORE = "datastore" + SEED_DATASET_REPOSITORY = "seed_dataset_repository" MODEL_REGISTRY = "model_registry" class ResourceProvider(ConfigBase): artifact_storage: ArtifactStorage - blob_storage: ManagedBlobStorage | None = None - datastore: SeedDatasetDataStore | None = None - model_registry: ModelRegistry | None = None + blob_storage: ManagedBlobStorage + seed_dataset_repository: SeedDatasetRepository + model_registry: ModelRegistry def create_resource_provider( @@ -32,16 +32,19 @@ def create_resource_provider( model_configs: list[ModelConfig], secret_resolver: SecretResolver, model_provider_registry: ModelProviderRegistry, - datastore: SeedDatasetDataStore | None = None, - blob_storage: ManagedBlobStorage | None = None, + seed_dataset_source_registry: SeedDatasetSourceRegistry, + blob_storage: ManagedBlobStorage, ) -> ResourceProvider: return ResourceProvider( artifact_storage=artifact_storage, - datastore=datastore, + seed_dataset_repository=SeedDatasetRepository( + registry=seed_dataset_source_registry, + secret_resolver=secret_resolver, + ), model_registry=create_model_registry( model_configs=model_configs, secret_resolver=secret_resolver, model_provider_registry=model_provider_registry, ), - blob_storage=blob_storage or init_managed_blob_storage(), + blob_storage=blob_storage, ) diff --git a/src/data_designer/engine/resources/seed_dataset_data_store.py b/src/data_designer/engine/resources/seed_dataset_data_store.py deleted file mode 100644 index 3295f41e..00000000 --- a/src/data_designer/engine/resources/seed_dataset_data_store.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from abc import ABC, abstractmethod - -import duckdb -from huggingface_hub import HfApi, HfFileSystem - -from data_designer.logging import quiet_noisy_logger - -quiet_noisy_logger("httpx") - -_HF_DATASETS_PREFIX = "hf://datasets/" - - -class MalformedFileIdError(Exception): - """Raised when file_id format is invalid.""" - - -class SeedDatasetDataStore(ABC): - """Abstract base class for dataset storage implementations.""" - - @abstractmethod - def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ... - - @abstractmethod - def get_dataset_uri(self, file_id: str) -> str: ... 
- - -class LocalSeedDatasetDataStore(SeedDatasetDataStore): - """Local filesystem-based dataset storage.""" - - def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: - return duckdb.connect() - - def get_dataset_uri(self, file_id: str) -> str: - return file_id - - -class HfHubSeedDatasetDataStore(SeedDatasetDataStore): - """Hugging Face and Data Store dataset storage.""" - - def __init__(self, endpoint: str, token: str | None): - self.hfapi = HfApi(endpoint=endpoint, token=token) - self.hffs = HfFileSystem(endpoint=endpoint, token=token) - - def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: - conn = duckdb.connect() - conn.register_filesystem(self.hffs) - return conn - - def get_dataset_uri(self, file_id: str) -> str: - identifier = file_id.removeprefix(_HF_DATASETS_PREFIX) - repo_id, filename = self._get_repo_id_and_filename(identifier) - return f"{_HF_DATASETS_PREFIX}{repo_id}/{filename}" - - def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]: - """Extract repo_id and filename from identifier.""" - parts = identifier.split("/", 2) - if len(parts) < 3: - raise MalformedFileIdError( - "Could not extract repo id and filename from file_id, " - "expected 'hf://datasets/{repo-namespace}/{repo-name}/{filename}'" - ) - repo_ns, repo_name, filename = parts - return f"{repo_ns}/{repo_name}", filename diff --git a/src/data_designer/engine/resources/seed_dataset_source.py b/src/data_designer/engine/resources/seed_dataset_source.py new file mode 100644 index 00000000..73f7898b --- /dev/null +++ b/src/data_designer/engine/resources/seed_dataset_source.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from functools import cached_property +from typing import Annotated, Literal, Self, TypeAlias + +import duckdb +from huggingface_hub import HfFileSystem +from pydantic import BaseModel, Field, field_validator, model_validator + +from data_designer.engine.errors import UnknownSeedDatasetSourceError +from data_designer.engine.secret_resolver import SecretResolver +from data_designer.logging import quiet_noisy_logger + +quiet_noisy_logger("httpx") + +_HF_DATASETS_PREFIX = "hf://datasets/" + + +class MalformedFileIdError(Exception): + """Raised when file_id format is invalid.""" + + +class SeedDatasetSource(BaseModel, ABC): + """Abstract base class for dataset storage implementations.""" + + name: str + + @abstractmethod + def resolve(self, secret_resolver: SecretResolver) -> Self: ... + + @abstractmethod + def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ... + + @abstractmethod + def get_dataset_uri(self, file_id: str) -> str: ... 
+ + +class LocalSeedDatasetSource(SeedDatasetSource): + """Local filesystem-based dataset storage.""" + + source_type: Literal["local"] = "local" + + name: str = "local" + + def resolve(self, secret_resolver: SecretResolver) -> Self: + return self.model_copy(deep=True) + + def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: + return duckdb.connect() + + def get_dataset_uri(self, file_id: str) -> str: + return file_id + + +class HfHubSeedDatasetSource(SeedDatasetSource): + """Hugging Face and Data Store dataset storage.""" + + source_type: Literal["hf_hub"] = "hf_hub" + + name: str = "hf_hub" + endpoint: str + token: str | None = None + + def resolve(self, secret_resolver: SecretResolver) -> Self: + update = {} + if self.token is not None: + update = {"token": secret_resolver.resolve(self.token)} + return self.model_copy(deep=True, update=update) + + def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: + conn = duckdb.connect() + conn.register_filesystem(HfFileSystem(endpoint=self.endpoint, token=self.token)) + return conn + + def get_dataset_uri(self, file_id: str) -> str: + identifier = file_id.removeprefix(_HF_DATASETS_PREFIX) + repo_id, filename = self._get_repo_id_and_filename(identifier) + return f"{_HF_DATASETS_PREFIX}{repo_id}/{filename}" + + def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]: + """Extract repo_id and filename from identifier.""" + parts = identifier.split("/", 2) + if len(parts) < 3: + raise MalformedFileIdError( + "Could not extract repo id and filename from file_id, " + "expected 'hf://datasets/{repo-namespace}/{repo-name}/{filename}'" + ) + repo_ns, repo_name, filename = parts + return f"{repo_ns}/{repo_name}", filename + + +SeedDatasetSourceT: TypeAlias = Annotated[ + LocalSeedDatasetSource | HfHubSeedDatasetSource, + Field(discriminator="source_type"), +] + + +class SeedDatasetSourceRegistry(BaseModel): + sources: list[SeedDatasetSourceT] + default: str | None = None + + @field_validator("sources", mode="after") + @classmethod + def validate_providers_not_empty(cls, v: list[SeedDatasetSourceT]) -> list[SeedDatasetSourceT]: + if len(v) == 0: + raise ValueError("At least one source must be defined") + return v + + @field_validator("sources", mode="after") + @classmethod + def validate_providers_have_unique_names(cls, v: list[SeedDatasetSourceT]) -> list[SeedDatasetSourceT]: + names = set() + dupes = set() + for source in v: + if source.name in names: + dupes.add(source.name) + names.add(source.name) + + if len(dupes) > 0: + raise ValueError(f"Seed dataset sources must have unique names, found duplicates: {dupes}") + return v + + @model_validator(mode="after") + def check_implicit_default(self) -> Self: + if self.default is None and len(self.sources) != 1: + raise ValueError("A default source must be specified if multiple model sources are defined") + return self + + @model_validator(mode="after") + def check_default_exists(self) -> Self: + if self.default and self.default not in self._sources_dict: + raise ValueError(f"Specified default {self.default!r} not found in sources list") + return self + + def get_default_source_name(self) -> str: + return self.default or self.sources[0].name + + @cached_property + def _sources_dict(self) -> dict[str, SeedDatasetSourceT]: + return {s.name: s for s in self.sources} + + def get_source(self, name: str | None) -> SeedDatasetSourceT: + if name is None: + name = self.get_default_source_name() + + try: + return self._sources_dict[name] + except KeyError: + raise 
UnknownSeedDatasetSourceError(f"No seed dataset source named {name!r} registered") + + +class SeedDatasetRepository: + def __init__( + self, + registry: SeedDatasetSourceRegistry, + secret_resolver: SecretResolver, + ): + self._registry = registry + self._secret_resolver = secret_resolver + + def create_duckdb_connection(self, source_name: str | None) -> duckdb.DuckDBPyConnection: + return self._get_resolved_source(source_name).create_duckdb_connection() + + def get_dataset_uri(self, file_id: str, source_name: str | None) -> str: + return self._get_resolved_source(source_name).get_dataset_uri(file_id) + + def _get_resolved_source(self, source_name: str | None) -> SeedDatasetSource: + unresolved_source = self._registry.get_source(source_name) + return unresolved_source.resolve(self._secret_resolver) diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 3b3fd96f..db98ea66 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -17,7 +17,6 @@ from ..config.config_builder import DataDesignerConfigBuilder from ..config.data_designer_config import DataDesignerConfig from ..config.dataset_builders import BuildStage -from ..config.datastore import DatastoreSettings from ..config.models import ( ImageContext, ImageFormat, @@ -49,7 +48,7 @@ UniformSamplerParams, UUIDSamplerParams, ) -from ..config.seed import DatastoreSeedDatasetReference, IndexRange, PartitionBlock, SamplingStrategy, SeedConfig +from ..config.seed import IndexRange, PartitionBlock, SamplingStrategy, SeedConfig from ..config.utils.code_lang import CodeLang from ..config.utils.misc import can_run_data_designer_locally from ..config.validator_params import ( @@ -83,8 +82,6 @@ "DataDesignerConfig", "DataDesignerConfigBuilder", "BuildStage", - "DatastoreSeedDatasetReference", - "DatastoreSettings", "DatetimeSamplerParams", "DropColumnsProcessorConfig", "ExpressionColumnConfig", diff --git a/src/data_designer/interface/data_designer.py b/src/data_designer/interface/data_designer.py index 13d16cb5..c94157de 100644 --- a/src/data_designer/interface/data_designer.py +++ b/src/data_designer/interface/data_designer.py @@ -10,7 +10,7 @@ from data_designer.config.base import DEFAULT_NUM_RECORDS, DataDesignerInterface from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.preview_results import PreviewResults -from data_designer.config.seed import LocalSeedDatasetReference +from data_designer.config.seed import HfHubSeedDatasetReference, LocalSeedDatasetReference from data_designer.config.utils.io_helpers import write_seed_dataset from data_designer.engine.analysis.dataset_profiler import ( DataDesignerDatasetProfiler, @@ -20,14 +20,19 @@ from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs from data_designer.engine.model_provider import ModelProvider, resolve_model_provider_registry -from data_designer.engine.models.registry import create_model_registry from data_designer.engine.resources.managed_storage import init_managed_blob_storage -from data_designer.engine.resources.resource_provider import ResourceProvider -from data_designer.engine.resources.seed_dataset_data_store import ( - HfHubSeedDatasetDataStore, - LocalSeedDatasetDataStore, +from data_designer.engine.resources.resource_provider import ResourceProvider, create_resource_provider +from 
data_designer.engine.resources.seed_dataset_source import ( + HfHubSeedDatasetSource, + LocalSeedDatasetSource, + SeedDatasetSourceRegistry, +) +from data_designer.engine.secret_resolver import ( + CompositeResolver, + EnvironmentResolver, + PlaintextResolver, + SecretResolver, ) -from data_designer.engine.secret_resolver import EnvironmentResolver, SecretResolver from data_designer.interface.errors import ( DataDesignerGenerationError, DataDesignerProfilingError, @@ -38,6 +43,8 @@ DEFAULT_BUFFER_SIZE = 1000 +DEFAULT_SECRET_RESOLVER = CompositeResolver(resolvers=[EnvironmentResolver(), PlaintextResolver()]) + logger = logging.getLogger(__name__) @@ -55,7 +62,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]): model_providers: Optional list of model providers for LLM generation. If None, uses default providers. secret_resolver: Resolver for handling secrets and credentials. Defaults to - EnvironmentResolver which reads secrets from environment variables. + a resolver that first checks environment variables before assuming plaintext values. blob_storage_path: Path to the blob storage directory. Note this parameter is temporary and will be removed after we update person sampling for the library. """ @@ -65,10 +72,10 @@ def __init__( artifact_path: Path | str, *, model_providers: list[ModelProvider] | None = None, - secret_resolver: SecretResolver = EnvironmentResolver(), + secret_resolver: SecretResolver | None = None, blob_storage_path: Path | str | None = None, ): - self._secret_resolver = secret_resolver + self._secret_resolver = secret_resolver or DEFAULT_SECRET_RESOLVER self._artifact_path = Path(artifact_path) self._buffer_size = DEFAULT_BUFFER_SIZE self._blob_storage = ( @@ -254,21 +261,34 @@ def _create_resource_provider( self, dataset_name: str, config_builder: DataDesignerConfigBuilder ) -> ResourceProvider: model_configs = config_builder.model_configs + seed_dataset_source_registry = self._create_seed_dataset_source_registry(config_builder) ArtifactStorage.mkdir_if_needed(self._artifact_path) - return ResourceProvider( + return create_resource_provider( artifact_storage=ArtifactStorage(artifact_path=self._artifact_path, dataset_name=dataset_name), - model_registry=create_model_registry( - model_configs=model_configs, - model_provider_registry=self._model_provider_registry, - secret_resolver=self._secret_resolver, - ), + model_configs=model_configs, + secret_resolver=self._secret_resolver, + model_provider_registry=self._model_provider_registry, + seed_dataset_source_registry=seed_dataset_source_registry, blob_storage=self._blob_storage, - datastore=( - LocalSeedDatasetDataStore() - if (settings := config_builder.get_seed_datastore_settings()) is None - else HfHubSeedDatasetDataStore( - endpoint=settings.endpoint, - token=settings.token, - ) - ), ) + + def _create_seed_dataset_source_registry( + self, config_builder: DataDesignerConfigBuilder + ) -> SeedDatasetSourceRegistry: + if (seed_config := config_builder.get_seed_config()) is None: + return SeedDatasetSourceRegistry(sources=[LocalSeedDatasetSource()]) + + reference = config_builder.get_seed_dataset_reference() + + if isinstance(reference, HfHubSeedDatasetReference): + source = HfHubSeedDatasetSource( + endpoint=reference.endpoint, + token=reference.token, + ) + else: + source = LocalSeedDatasetSource() + + if seed_config.source: + source.name = seed_config.source + + return SeedDatasetSourceRegistry(sources=[source]) diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py 
index 68a78e5f..668fdb28 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -25,12 +25,11 @@ ) from data_designer.config.config_builder import BuilderConfig, DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig -from data_designer.config.datastore import DatastoreSettings from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError from data_designer.config.models import InferenceParameters, ModelConfig from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams -from data_designer.config.seed import DatastoreSeedDatasetReference, SamplingStrategy +from data_designer.config.seed import HfHubSeedDatasetReference, SamplingStrategy from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import CodeValidatorParams @@ -41,16 +40,13 @@ class DummyStructuredModel(BaseModel): @pytest.fixture def mock_fetch_seed_dataset_column_names(): - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch_seed: - mock_fetch_seed.return_value = ["id", "name", "age", "city"] - yield mock_fetch_seed + with patch.object(HfHubSeedDatasetReference, "get_column_names", return_value=["id", "name", "age", "city"]): + yield @pytest.fixture -def stub_data_designer_builder(stub_data_designer_builder_config_str): - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch_seed: - mock_fetch_seed.return_value = ["id", "name", "age", "city"] - yield DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str) +def stub_data_designer_builder(stub_data_designer_builder_config_str, mock_fetch_seed_dataset_column_names): + yield DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str) def test_loading_model_configs_in_constructor(stub_model_configs): @@ -635,12 +631,12 @@ def test_seed_config(stub_complete_builder): def test_with_seed_dataset_basic(stub_empty_builder, mock_fetch_seed_dataset_column_names): """Test with_seed_dataset method with basic parameters.""" - datastore_settings = DatastoreSettings(endpoint="https://huggingface.co", token="test-token") - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "age", "city"] - result = stub_empty_builder.with_seed_dataset( - DatastoreSeedDatasetReference(dataset="test-repo/test-data.parquet", datastore_settings=datastore_settings) - ) + seed_dataset_reference = HfHubSeedDatasetReference( + dataset="test-repo/test-data.parquet", + endpoint="https://huggingface.co", + token="test-token", + ) + result = stub_empty_builder.with_seed_dataset(seed_dataset_reference) assert result is stub_empty_builder assert stub_empty_builder.get_seed_config().dataset == "test-repo/test-data.parquet" @@ -649,14 +645,16 @@ def test_with_seed_dataset_basic(stub_empty_builder, mock_fetch_seed_dataset_col def test_with_seed_dataset_sampling_strategy(stub_empty_builder, mock_fetch_seed_dataset_column_names): """Test with_seed_dataset with different sampling strategies.""" - datastore_settings = DatastoreSettings(endpoint="https://huggingface.co", token="test-token") + seed_dataset_reference = HfHubSeedDatasetReference( + dataset="test-repo/test-data.parquet", + 
endpoint="https://huggingface.co", + token="test-token", + ) - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "age", "city"] - stub_empty_builder.with_seed_dataset( - DatastoreSeedDatasetReference(dataset="test-repo/test-data.parquet", datastore_settings=datastore_settings), - sampling_strategy=SamplingStrategy.SHUFFLE, - ) + stub_empty_builder.with_seed_dataset( + seed_dataset_reference, + sampling_strategy=SamplingStrategy.SHUFFLE, + ) seed_config = stub_empty_builder.get_seed_config() assert seed_config.sampling_strategy == SamplingStrategy.SHUFFLE diff --git a/tests/config/test_datastore.py b/tests/config/test_datastore.py deleted file mode 100644 index 49955aae..00000000 --- a/tests/config/test_datastore.py +++ /dev/null @@ -1,236 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import MagicMock, patch - -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq -import pytest - -from data_designer.config.datastore import ( - DatastoreSettings, - fetch_seed_dataset_column_names, - get_file_column_names, - resolve_datastore_settings, - upload_to_hf_hub, -) -from data_designer.config.errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError -from data_designer.config.seed import DatastoreSeedDatasetReference, LocalSeedDatasetReference - - -@pytest.fixture -def datastore_settings(): - return DatastoreSettings(endpoint="https://testing.com", token="stub-token") - - -def _write_file(df, path, file_type): - if file_type == "parquet": - df.to_parquet(path) - elif file_type in {"json", "jsonl"}: - df.to_json(path, orient="records", lines=True) - else: - df.to_csv(path, index=False) - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_basic_parquet(tmp_path, file_type): - """Test get_file_column_names with basic parquet file.""" - test_data = { - "id": [1, 2, 3], - "name": ["Alice", "Bob", "Charlie"], - "age": [25, 30, 35], - "city": ["NYC", "LA", "Chicago"], - } - df = pd.DataFrame(test_data) - - parquet_path = tmp_path / f"test_data.{file_type}" - _write_file(df, parquet_path, file_type) - assert get_file_column_names(str(parquet_path), file_type) == df.columns.tolist() - - -def test_get_file_column_names_nested_fields(tmp_path): - """Test get_file_column_names with nested fields in parquet.""" - schema = pa.schema( - [ - pa.field( - "nested", pa.struct([pa.field("col1", pa.list_(pa.int32())), pa.field("col2", pa.list_(pa.int32()))]) - ), - ] - ) - - # For PyArrow, we need to structure the data as a list of records - nested_data = {"nested": [{"col1": [1, 2, 3], "col2": [4, 5, 6]}]} - nested_path = tmp_path / "nested_fields.parquet" - pq.write_table(pa.Table.from_pydict(nested_data, schema=schema), nested_path) - - column_names = get_file_column_names(str(nested_path), "parquet") - - assert column_names == ["nested"] - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_empty_parquet(tmp_path, file_type): - """Test get_file_column_names with empty parquet file.""" - empty_df = pd.DataFrame() - empty_path = tmp_path / f"empty.{file_type}" - _write_file(empty_df, empty_path, file_type) - - column_names = get_file_column_names(str(empty_path), file_type) - assert column_names == [] - - 
-@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_large_schema(tmp_path, file_type): - """Test get_file_column_names with many columns.""" - num_columns = 50 - test_data = {f"col_{i}": np.random.randn(10) for i in range(num_columns)} - df = pd.DataFrame(test_data) - - large_path = tmp_path / f"large_schema.{file_type}" - _write_file(df, large_path, file_type) - - column_names = get_file_column_names(str(large_path), file_type) - assert len(column_names) == num_columns - assert column_names == [f"col_{i}" for i in range(num_columns)] - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_special_characters(tmp_path, file_type): - """Test get_file_column_names with special characters in column names.""" - special_data = { - "column with spaces": [1], - "column-with-dashes": [2], - "column_with_underscores": [3], - "column.with.dots": [4], - "column123": [5], - "123column": [6], - "column!@#$%^&*()": [7], - } - df_special = pd.DataFrame(special_data) - special_path = tmp_path / f"special_chars.{file_type}" - _write_file(df_special, special_path, file_type) - - assert get_file_column_names(str(special_path), file_type) == df_special.columns.tolist() - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_unicode(tmp_path, file_type): - """Test get_file_column_names with unicode column names.""" - unicode_data = {"café": [1], "résumé": [2], "naïve": [3], "façade": [4], "garçon": [5], "über": [6], "schön": [7]} - df_unicode = pd.DataFrame(unicode_data) - - unicode_path = tmp_path / f"unicode_columns.{file_type}" - _write_file(df_unicode, unicode_path, file_type) - assert get_file_column_names(str(unicode_path), file_type) == df_unicode.columns.tolist() - - -@pytest.mark.parametrize("file_type", ["parquet", "csv", "json", "jsonl"]) -def test_get_file_column_names_with_glob_pattern(tmp_path, file_type): - df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) - for i in range(5): - _write_file(df, tmp_path / f"{i}.{file_type}", file_type) - assert get_file_column_names(f"{tmp_path}/*.{file_type}", file_type) == ["col1", "col2"] - - -def test_get_file_column_names_with_glob_pattern_error(tmp_path): - df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) - for i in range(5): - _write_file(df, tmp_path / f"{i}.parquet", "parquet") - with pytest.raises(InvalidFilePathError, match="No files found matching pattern"): - get_file_column_names(f"{tmp_path}/*.csv", "csv") - - -def test_get_file_column_names_error_handling(): - with pytest.raises(InvalidFilePathError, match="🛑 Unsupported file type: 'txt'"): - get_file_column_names("test.txt", "txt") - - with patch("data_designer.config.datastore.pq.read_schema") as mock_read_schema: - mock_read_schema.side_effect = Exception("Test error") - assert get_file_column_names("test.txt", "parquet") == [] - - with patch("data_designer.config.datastore.pq.read_schema") as mock_read_schema: - mock_col1 = MagicMock() - mock_col1.name = "col1" - mock_col2 = MagicMock() - mock_col2.name = "col2" - mock_read_schema.return_value = [mock_col1, mock_col2] - assert get_file_column_names("test.txt", "parquet") == ["col1", "col2"] - - -def test_fetch_seed_dataset_column_names_parquet_error_handling(datastore_settings): - with pytest.raises(InvalidFileFormatError, match="🛑 Unsupported file type: 'test.txt'"): - fetch_seed_dataset_column_names( - DatastoreSeedDatasetReference( - 
dataset="test/repo/test.txt", - datastore_settings=datastore_settings, - ) - ) - - -@patch("data_designer.config.datastore.get_file_column_names", autospec=True) -def test_fetch_seed_dataset_column_names_local_file(mock_get_file_column_names, datastore_settings): - mock_get_file_column_names.return_value = ["col1", "col2"] - with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file: - mock_is_file.return_value = True - assert fetch_seed_dataset_column_names(LocalSeedDatasetReference(dataset="test.parquet")) == ["col1", "col2"] - - -@patch("data_designer.config.datastore.HfFileSystem.open") -@patch("data_designer.config.datastore.get_file_column_names", autospec=True) -def test_fetch_seed_dataset_column_names_remote_file(mock_get_file_column_names, mock_hf_fs_open, datastore_settings): - mock_get_file_column_names.return_value = ["col1", "col2"] - assert fetch_seed_dataset_column_names( - DatastoreSeedDatasetReference( - dataset="test/repo/test.parquet", - datastore_settings=datastore_settings, - ) - ) == ["col1", "col2"] - mock_hf_fs_open.assert_called_once_with( - "datasets/test/repo/test.parquet", - ) - - -def test_resolve_datastore_settings(datastore_settings): - with pytest.raises(InvalidConfigError, match="Datastore settings are required"): - resolve_datastore_settings(None) - - with pytest.raises(InvalidConfigError, match="Invalid datastore settings format"): - resolve_datastore_settings("invalid_settings") - - assert resolve_datastore_settings(datastore_settings) == datastore_settings - assert resolve_datastore_settings(datastore_settings.model_dump()) == datastore_settings - - -@patch("data_designer.config.datastore.HfApi.upload_file", autospec=True) -@patch("data_designer.config.datastore.HfApi.create_repo", autospec=True) -def test_upload_to_hf_hub(mock_create_repo, mock_upload_file, datastore_settings): - with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file: - mock_is_file.return_value = True - - assert ( - upload_to_hf_hub("test.parquet", "test.parquet", "test/repo", datastore_settings) - == "test/repo/test.parquet" - ) - mock_create_repo.assert_called_once() - mock_upload_file.assert_called_once() - - -def test_upload_to_hf_hub_error_handling(datastore_settings): - with pytest.raises( - InvalidFilePathError, match="To upload a dataset to the datastore, you must provide a valid file path." 
- ): - upload_to_hf_hub("test.txt", "test.txt", "test/repo", datastore_settings) - - with pytest.raises( - InvalidFileFormatError, match="Dataset file extension '.parquet' does not match `filename` extension .'csv'" - ): - with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file: - mock_is_file.return_value = True - upload_to_hf_hub("test.parquet", "test.csv", "test/repo", datastore_settings) - - with pytest.raises(InvalidFileFormatError, match="Dataset files must be in "): - with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file: - mock_is_file.return_value = True - upload_to_hf_hub("test.text", "test.txt", "test/repo", datastore_settings) diff --git a/tests/config/test_seed.py b/tests/config/test_seed.py index 42acf189..86c0957b 100644 --- a/tests/config/test_seed.py +++ b/tests/config/test_seed.py @@ -2,12 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path +from typing import Union +from unittest.mock import MagicMock, patch +import numpy as np import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq import pytest -from data_designer.config.errors import InvalidFilePathError -from data_designer.config.seed import IndexRange, LocalSeedDatasetReference, PartitionBlock +from data_designer.config.errors import InvalidFileFormatError, InvalidFilePathError +from data_designer.config.seed import HfHubSeedDatasetReference, IndexRange, LocalSeedDatasetReference, PartitionBlock def create_partitions_in_path(temp_dir: Path, extension: str, num_files: int = 2) -> Path: @@ -99,3 +104,192 @@ def test_local_seed_dataset_reference_validation_error(tmp_path: Path): create_partitions_in_path(tmp_path, "parquet") with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"): LocalSeedDatasetReference(dataset=f"{tmp_path}/*.csv") + + +def test_local_seed_dataset_reference_file_format_error(tmp_path: Path): + df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) + filepath = tmp_path / "test.txt" + df.to_csv(filepath) + + with pytest.raises(InvalidFileFormatError): + LocalSeedDatasetReference(dataset=filepath) + + +def _write_file(df: pd.DataFrame, path: Union[str, Path], file_type: str): + if file_type == "parquet": + df.to_parquet(path) + elif file_type in {"json", "jsonl"}: + df.to_json(path, orient="records", lines=True) + else: + df.to_csv(path, index=False) + + +@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) +def test_get_column_names(tmp_path, file_type): + """Test get_file_column_names with basic parquet file.""" + test_data = { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + "city": ["NYC", "LA", "Chicago"], + } + df = pd.DataFrame(test_data) + + parquet_path = tmp_path / f"test_data.{file_type}" + _write_file(df, parquet_path, file_type) + + reference = LocalSeedDatasetReference(dataset=parquet_path) + assert reference.get_column_names() == df.columns.tolist() + + +def test_get_file_column_names_nested_fields(tmp_path): + """Test get_file_column_names with nested fields in parquet.""" + schema = pa.schema( + [ + pa.field( + "nested", pa.struct([pa.field("col1", pa.list_(pa.int32())), pa.field("col2", pa.list_(pa.int32()))]) + ), + ] + ) + + # For PyArrow, we need to structure the data as a list of records + nested_data = {"nested": [{"col1": [1, 2, 3], "col2": [4, 5, 6]}]} + nested_path = tmp_path / "nested_fields.parquet" + pq.write_table(pa.Table.from_pydict(nested_data, schema=schema), nested_path) + + 
reference = LocalSeedDatasetReference(dataset=nested_path) + column_names = reference.get_column_names() + + assert column_names == ["nested"] + + +@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) +def test_get_file_column_names_empty_parquet(tmp_path, file_type): + """Test get_file_column_names with empty parquet file.""" + empty_df = pd.DataFrame() + empty_path = tmp_path / f"empty.{file_type}" + _write_file(empty_df, empty_path, file_type) + + reference = LocalSeedDatasetReference(dataset=empty_path) + column_names = reference.get_column_names() + + assert column_names == [] + + +@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) +def test_get_file_column_names_large_schema(tmp_path, file_type): + """Test get_file_column_names with many columns.""" + num_columns = 50 + test_data = {f"col_{i}": np.random.randn(10) for i in range(num_columns)} + df = pd.DataFrame(test_data) + + large_path = tmp_path / f"large_schema.{file_type}" + _write_file(df, large_path, file_type) + + reference = LocalSeedDatasetReference(dataset=large_path) + column_names = reference.get_column_names() + + assert len(column_names) == num_columns + assert column_names == [f"col_{i}" for i in range(num_columns)] + + +@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) +def test_get_file_column_names_special_characters(tmp_path, file_type): + """Test get_file_column_names with special characters in column names.""" + special_data = { + "column with spaces": [1], + "column-with-dashes": [2], + "column_with_underscores": [3], + "column.with.dots": [4], + "column123": [5], + "123column": [6], + "column!@#$%^&*()": [7], + } + df_special = pd.DataFrame(special_data) + special_path = tmp_path / f"special_chars.{file_type}" + _write_file(df_special, special_path, file_type) + + reference = LocalSeedDatasetReference(dataset=special_path) + column_names = reference.get_column_names() + + assert column_names == df_special.columns.tolist() + + +@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) +def test_get_file_column_names_unicode(tmp_path, file_type): + """Test get_file_column_names with unicode column names.""" + unicode_data = {"café": [1], "résumé": [2], "naïve": [3], "façade": [4], "garçon": [5], "über": [6], "schön": [7]} + df_unicode = pd.DataFrame(unicode_data) + + unicode_path = tmp_path / f"unicode_columns.{file_type}" + _write_file(df_unicode, unicode_path, file_type) + + reference = LocalSeedDatasetReference(dataset=unicode_path) + column_names = reference.get_column_names() + + assert column_names == df_unicode.columns.tolist() + + +@pytest.mark.parametrize("file_type", ["parquet", "csv", "json", "jsonl"]) +def test_get_file_column_names_with_glob_pattern(tmp_path, file_type): + df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) + for i in range(5): + _write_file(df, tmp_path / f"{i}.{file_type}", file_type) + + reference = LocalSeedDatasetReference(dataset=f"{tmp_path}/*.{file_type}") + column_names = reference.get_column_names() + + assert column_names == ["col1", "col2"] + + +def test_get_file_column_names_error_handling(tmp_path): + df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) + filepath = tmp_path / "test.parquet" + df.to_parquet(filepath) + + reference = LocalSeedDatasetReference(dataset=filepath) + + with patch("data_designer.config.seed.pq.read_schema") as mock_read_schema: + mock_read_schema.side_effect = Exception("Test error") + reference.get_column_names() + + with 
patch("data_designer.config.seed.pq.read_schema") as mock_read_schema: + mock_col1 = MagicMock() + mock_col1.name = "col1" + mock_col2 = MagicMock() + mock_col2.name = "col2" + mock_read_schema.return_value = [mock_col1, mock_col2] + + column_names = reference.get_column_names() + assert column_names == ["col1", "col2"] + + +TEST_ENDPOINT = "https://testing.com" +TEST_TOKEN = "stub-token" + + +def test_fetch_seed_dataset_column_names_parquet_error_handling(): + reference = HfHubSeedDatasetReference( + dataset="test/repo/test.txt", + endpoint=TEST_ENDPOINT, + token=TEST_TOKEN, + ) + with pytest.raises(InvalidFileFormatError, match="🛑 Unsupported file type: 'test.txt'"): + reference.get_column_names() + + +@patch("data_designer.config.seed.HfFileSystem.open") +@patch("data_designer.config.seed._get_file_column_names", autospec=True) +def test_fetch_seed_dataset_column_names_remote_file(mock_get_file_column_names, mock_hf_fs_open): + mock_get_file_column_names.return_value = ["col1", "col2"] + + reference = HfHubSeedDatasetReference( + dataset="test/repo/test.parquet", + endpoint=TEST_ENDPOINT, + token=TEST_TOKEN, + ) + + assert reference.get_column_names() == ["col1", "col2"] + mock_hf_fs_open.assert_called_once_with( + "datasets/test/repo/test.parquet", + ) diff --git a/tests/config/utils/test_visualization.py b/tests/config/utils/test_visualization.py index aaa3201a..4d23665e 100644 --- a/tests/config/utils/test_visualization.py +++ b/tests/config/utils/test_visualization.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import patch import pandas as pd import pytest @@ -26,21 +25,18 @@ def validation_output(): @pytest.fixture def config_builder_with_validation(): """Fixture providing a DataDesignerConfigBuilder with a validation column.""" - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["code"] - - builder = DataDesignerConfigBuilder() - - # Add a validation column configuration - builder.add_column( - name="code_validation_result", - column_type="validation", - target_columns=["code"], - validator_type="code", - validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON), - ) - - return builder + builder = DataDesignerConfigBuilder() + + # Add a validation column configuration + builder.add_column( + name="code_validation_result", + column_type="validation", + target_columns=["code"], + validator_type="code", + validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON), + ) + + return builder def test_display_sample_record_twice_no_errors(validation_output, config_builder_with_validation): diff --git a/tests/conftest.py b/tests/conftest.py index 62ecc60a..e7bed7a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,8 +16,8 @@ from data_designer.config.columns import SamplerColumnConfig from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig -from data_designer.config.datastore import DatastoreSettings from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.seed import HfHubSeedDatasetReference @pytest.fixture @@ -117,9 +117,11 @@ def stub_data_designer_builder_config_str(stub_data_designer_config_str: str) -> data_designer: {textwrap.indent(stub_data_designer_config_str, prefix=" ")} -datastore_settings: +seed_dataset_reference: + 
dataset: test-repo/testing/data.csv endpoint: http://test-endpoint:3000/v1/hf token: stub-token + reference_type: hf_hub """ @@ -152,17 +154,10 @@ def stub_empty_builder(stub_model_configs: list[ModelConfig]) -> DataDesignerCon @pytest.fixture def stub_complete_builder(stub_data_designer_builder_config_str: str) -> DataDesignerConfigBuilder: - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "age", "city"] + with patch.object(HfHubSeedDatasetReference, "get_column_names", return_value=["id", "name", "age", "city"]): return DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str) -@pytest.fixture -def stub_datastore_settings(): - """Test datastore settings with testing endpoint and token.""" - return DatastoreSettings(endpoint="https://testing.com", token="stub-token") - - @pytest.fixture def stub_dataframe(): return pd.DataFrame( diff --git a/tests/engine/analysis/conftest.py b/tests/engine/analysis/conftest.py index 7d571d5f..8c9bebe1 100644 --- a/tests/engine/analysis/conftest.py +++ b/tests/engine/analysis/conftest.py @@ -24,7 +24,9 @@ from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.models.registry import ModelRegistry from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry +from data_designer.engine.resources.managed_storage import ManagedBlobStorage from data_designer.engine.resources.resource_provider import ResourceProvider +from data_designer.engine.resources.seed_dataset_source import SeedDatasetRepository @fixture @@ -78,7 +80,12 @@ def dataset_profiler( profiler = DataDesignerDatasetProfiler( config=DatasetProfilerConfig(column_configs=column_configs), - resource_provider=ResourceProvider(artifact_storage=artifact_storage, model_registry=model_registry), + resource_provider=ResourceProvider( + artifact_storage=artifact_storage, + model_registry=model_registry, + seed_dataset_repository=Mock(spec=SeedDatasetRepository), + blob_storage=Mock(spec=ManagedBlobStorage), + ), ) return profiler @@ -144,9 +151,3 @@ def stub_judge_distributions(): distributions={"quality": NumericalDistribution(min=0, max=4, mean=2.0, stddev=1.4, median=2.0)}, histograms={"quality": CategoricalHistogramData(categories=[4, 3, 2, 1, 0], counts=[1, 1, 1, 1, 1])}, ) - - -@fixture -def stub_resource_provider_no_model_registry(tmp_path): - """Create a mock ResourceProvider for testing.""" - return ResourceProvider(artifact_storage=ArtifactStorage(artifact_path=tmp_path)) diff --git a/tests/engine/analysis/test_dataset_profiler.py b/tests/engine/analysis/test_dataset_profiler.py index 154f6360..b3ae53e1 100644 --- a/tests/engine/analysis/test_dataset_profiler.py +++ b/tests/engine/analysis/test_dataset_profiler.py @@ -10,7 +10,6 @@ from data_designer.config.sampler_params import CategorySamplerParams, SamplerType from data_designer.engine.analysis.column_profilers.judge_score_profiler import JudgeScoreProfilerConfig from data_designer.engine.analysis.dataset_profiler import ( - DataDesignerDatasetProfiler, DatasetProfilerConfig, ) from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError @@ -91,48 +90,6 @@ def test_dataset_profiler_profile_dataset_with_column_profilers( stub_model_facade.generate.assert_called() -@patch( - "data_designer.engine.analysis.dataset_profiler.DataDesignerDatasetProfiler._validate_schema_consistency", - autospec=True, -) -def 
test_dataset_profiler_requires_model_registry_with_column_profiler_configs( - mock_validate_schema_consistency, stub_resource_provider_no_model_registry -): - column_configs = [ - SamplerColumnConfig( - name="test_id", - sampler_type=SamplerType.CATEGORY, - params=CategorySamplerParams(values=["a", "b", "c"]), - ), - ] - - mock_validate_schema_consistency.return_value = None - - DataDesignerDatasetProfiler( - config=DatasetProfilerConfig( - column_configs=column_configs, - ), - resource_provider=stub_resource_provider_no_model_registry, - ) - - with pytest.raises( - DatasetProfilerConfigurationError, - match="Model registry is required for column profiler configs", - ): - DataDesignerDatasetProfiler( - config=DatasetProfilerConfig( - column_configs=column_configs, - column_profiler_configs=[ - JudgeScoreProfilerConfig( - model_alias="model-alias", - summary_score_sample_size=5, - ) - ], - ), - resource_provider=stub_resource_provider_no_model_registry, - ) - - def test_profile_dataset_no_applicable_column_types(dataset_profiler, stub_df, stub_model_facade): mock_profiler = Mock() mock_profiler.metadata.return_value.applicable_column_types = ["NONEXISTENT_COLUMN_TYPE"] diff --git a/tests/engine/column_generators/generators/test_seed_dataset.py b/tests/engine/column_generators/generators/test_seed_dataset.py index ebb9c72a..0bda805b 100644 --- a/tests/engine/column_generators/generators/test_seed_dataset.py +++ b/tests/engine/column_generators/generators/test_seed_dataset.py @@ -18,7 +18,7 @@ ) from data_designer.engine.column_generators.utils.errors import SeedDatasetError from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig -from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType +from data_designer.engine.resources.resource_provider import ResourceType @pytest.fixture @@ -36,9 +36,9 @@ def stub_seed_dataset_config(): @pytest.fixture def stub_seed_dataset_generator(stub_resource_provider, stub_duckdb_conn, stub_seed_dataset_config): mock_provider = stub_resource_provider - mock_datastore = mock_provider.datastore - mock_datastore.create_duckdb_connection.return_value = stub_duckdb_conn - mock_datastore.get_dataset_uri.return_value = "test_uri" + mock_seed_dataset_repository = mock_provider.seed_dataset_repository + mock_seed_dataset_repository.create_duckdb_connection.return_value = stub_duckdb_conn + mock_seed_dataset_repository.get_dataset_uri.return_value = "test_uri" return SeedDatasetColumnGenerator(config=stub_seed_dataset_config, resource_provider=mock_provider) @@ -107,7 +107,7 @@ def seed_dataset_jsonl(sample_dataframe): def test_seed_dataset_column_generator_metadata(): metadata = SeedDatasetColumnGenerator.metadata() assert metadata.generation_strategy == GenerationStrategy.FULL_COLUMN - assert ResourceType.DATASTORE in metadata.required_resources + assert ResourceType.SEED_DATASET_REPOSITORY in metadata.required_resources def test_seed_dataset_column_generator_config_structure(): @@ -328,7 +328,7 @@ def test_seed_dataset_column_generator_sample_records_zero_record_error(stub_see gen._batch_reader = Mock() gen._batch_reader.read_next_batch.return_value = mock_batch - with pytest.raises(RuntimeError, match="🛑 Something went wrong while reading from the datastore"): + with pytest.raises(RuntimeError, match="🛑 Something went wrong while reading from the seed dataset source"): gen._sample_records(3) @@ -360,7 +360,7 @@ def test_seed_dataset_column_generator_sample_records_multiple_batches(stub_seed 
def create_generator_with_real_file( file_path: str, - stub_resource_provider: ResourceProvider, + stub_resource_provider, sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, selection_strategy: IndexRange | PartitionBlock | None = None, ) -> SeedDatasetColumnGenerator: @@ -382,9 +382,9 @@ def create_generator_with_real_file( real_conn = duckdb.connect() mock_provider = stub_resource_provider - mock_datastore = mock_provider.datastore - mock_datastore.create_duckdb_connection.return_value = real_conn - mock_datastore.get_dataset_uri.return_value = file_path + mock_seed_dataset_repository = mock_provider.seed_dataset_repository + mock_seed_dataset_repository.create_duckdb_connection.return_value = real_conn + mock_seed_dataset_repository.get_dataset_uri.return_value = file_path generator = SeedDatasetColumnGenerator(config=config, resource_provider=mock_provider) return generator @@ -441,8 +441,8 @@ def test_seed_dataset_generator_ordered_sampling(fixture_name, stub_resource_pro real_conn = duckdb.connect() mock_provider = stub_resource_provider - mock_provider.datastore.create_duckdb_connection.return_value = real_conn - mock_provider.datastore.get_dataset_uri.return_value = file_path + mock_provider.seed_dataset_repository.create_duckdb_connection.return_value = real_conn + mock_provider.seed_dataset_repository.get_dataset_uri.return_value = file_path generator = SeedDatasetColumnGenerator(config=config, resource_provider=mock_provider) @@ -477,8 +477,8 @@ def test_seed_dataset_generator_shuffle_sampling(fixture_name, stub_resource_pro real_conn = duckdb.connect() mock_provider = stub_resource_provider - mock_provider.datastore.create_duckdb_connection.return_value = real_conn - mock_provider.datastore.get_dataset_uri.return_value = file_path + mock_provider.seed_dataset_repository.create_duckdb_connection.return_value = real_conn + mock_provider.seed_dataset_repository.get_dataset_uri.return_value = file_path generator = SeedDatasetColumnGenerator(config=config, resource_provider=mock_provider) diff --git a/tests/engine/conftest.py b/tests/engine/conftest.py index dc30ba25..8769d344 100644 --- a/tests/engine/conftest.py +++ b/tests/engine/conftest.py @@ -11,6 +11,7 @@ from data_designer.engine.models.registry import ModelRegistry from data_designer.engine.resources.managed_storage import ManagedBlobStorage from data_designer.engine.resources.resource_provider import ResourceProvider +from data_designer.engine.resources.seed_dataset_source import SeedDatasetRepository @pytest.fixture @@ -35,7 +36,7 @@ def stub_resource_provider(tmp_path, stub_model_facade): mock_provider.model_registry = mock_model_registry mock_provider.artifact_storage = ArtifactStorage(artifact_path=tmp_path) mock_provider.blob_storage = Mock(spec=ManagedBlobStorage) - mock_provider.datastore = Mock() + mock_provider.seed_dataset_repository = Mock(spec=SeedDatasetRepository) return mock_provider diff --git a/tests/engine/resources/test_resource_provider.py b/tests/engine/resources/test_resource_provider.py deleted file mode 100644 index 5d7b07a8..00000000 --- a/tests/engine/resources/test_resource_provider.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -import inspect -from unittest.mock import Mock, patch - -import pytest - -from data_designer.engine.resources.resource_provider import ( - ResourceProvider, - create_resource_provider, -) - - -def test_resource_provider_artifact_storage_required(): - with pytest.raises(ValueError, match="Field required"): - ResourceProvider() - - -@pytest.mark.parametrize( - "test_case,expected_error", - [ - ("model_registry_creation_error", "Model registry creation failed"), - ], -) -def test_create_resource_provider_error_cases(test_case, expected_error): - mock_artifact_storage = Mock() - mock_model_configs = [Mock(), Mock()] - mock_secret_resolver = Mock() - mock_model_provider_registry = Mock() - - with patch("data_designer.engine.resources.resource_provider.create_model_registry") as mock_create_model_registry: - mock_create_model_registry.side_effect = Exception(expected_error) - - with pytest.raises(Exception, match=expected_error): - create_resource_provider( - artifact_storage=mock_artifact_storage, - model_configs=mock_model_configs, - secret_resolver=mock_secret_resolver, - model_provider_registry=mock_model_provider_registry, - ) - - -def test_create_resource_provider_function_exists(): - assert callable(create_resource_provider) - - sig = inspect.signature(create_resource_provider) - params = list(sig.parameters.keys()) - - expected_params = [ - "artifact_storage", - "model_configs", - "secret_resolver", - "model_provider_registry", - "datastore", - "blob_storage", - ] - - for param in expected_params: - assert param in params diff --git a/tests/engine/resources/test_seed_dataset_source.py b/tests/engine/resources/test_seed_dataset_source.py new file mode 100644 index 00000000..3f9410b6 --- /dev/null +++ b/tests/engine/resources/test_seed_dataset_source.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import json + +import pytest + +from data_designer.engine.errors import SecretResolutionError +from data_designer.engine.resources.seed_dataset_source import ( + HfHubSeedDatasetSource, + SeedDatasetRepository, + SeedDatasetSourceRegistry, +) +from data_designer.engine.secret_resolver import EnvironmentResolver + +HF_ENDPOINT = "https://huggingface.co" +NDS_ENDPOINT = "http://datastore:3000/v1/hf" + + +def test_hf_hub_source_resolution(monkeypatch: pytest.MonkeyPatch): + secret_resolver = EnvironmentResolver() + + token_ref = "MY_HF_TOKEN" + token_raw_value = "token-raw-value" + hf_source = HfHubSeedDatasetSource(endpoint=HF_ENDPOINT, token=token_ref) + + with pytest.raises(SecretResolutionError): + hf_source.resolve(secret_resolver) + + monkeypatch.setenv(token_ref, token_raw_value) + resolved_hf_source = hf_source.resolve(secret_resolver) + + assert resolved_hf_source.token == token_raw_value + + +def test_registry_from_simple_config(): + config = { + "sources": [ + {"source_type": "hf_hub", "endpoint": HF_ENDPOINT}, + ], + } + + registry = SeedDatasetSourceRegistry.model_validate_json(json.dumps(config)) + assert len(registry.sources) == 1 + assert isinstance(registry.sources[0], HfHubSeedDatasetSource) + assert registry.sources[0].name == "hf_hub" + assert registry.sources[0].token is None + + +def test_registry_from_more_complex_config(): + config = { + "default": "hf", + "sources": [ + {"source_type": "hf_hub", "name": "hf", "endpoint": HF_ENDPOINT, "token": "HF_TOKEN"}, + {"source_type": "hf_hub", "name": "nds", "endpoint": NDS_ENDPOINT}, + ], + } + + registry = SeedDatasetSourceRegistry.model_validate_json(json.dumps(config)) + assert len(registry.sources) == 2 + assert isinstance(registry.sources[0], HfHubSeedDatasetSource) + assert isinstance(registry.sources[1], HfHubSeedDatasetSource) + + +def test_registry_validation_errors(): + with pytest.raises(ValueError) as excinfo: + SeedDatasetSourceRegistry( + sources=[ + HfHubSeedDatasetSource(name="hf", endpoint=HF_ENDPOINT), + HfHubSeedDatasetSource(name="nds", endpoint=NDS_ENDPOINT), + ] + ) + assert "default source" in str(excinfo.value) + + with pytest.raises(ValueError) as excinfo: + SeedDatasetSourceRegistry( + sources=[ + HfHubSeedDatasetSource(name="name", endpoint=HF_ENDPOINT), + HfHubSeedDatasetSource(name="name", endpoint=NDS_ENDPOINT), + ] + ) + assert "duplicates" in str(excinfo.value) + + with pytest.raises(ValueError) as excinfo: + SeedDatasetSourceRegistry( + default="not-defined", + sources=[HfHubSeedDatasetSource(endpoint=HF_ENDPOINT)], + ) + assert "not found" in str(excinfo.value) + + +def test_get_uri_through_repository(): + source_name = "source-name" + registry = SeedDatasetSourceRegistry( + sources=[HfHubSeedDatasetSource(name=source_name, endpoint=HF_ENDPOINT)], + ) + secret_resolver = EnvironmentResolver() + repository = SeedDatasetRepository(registry=registry, secret_resolver=secret_resolver) + + file_id = "namespace/repo/file.parquet" + + # The registry only has one source defined, so it has an implicit default + uri1 = repository.get_dataset_uri(file_id, source_name) + uri2 = repository.get_dataset_uri(file_id, None) + + assert uri1 == uri2 == "hf://datasets/namespace/repo/file.parquet" diff --git a/tests/engine/test_configurable_task.py b/tests/engine/test_configurable_task.py index 1210b448..61cb0819 100644 --- a/tests/engine/test_configurable_task.py +++ b/tests/engine/test_configurable_task.py @@ -16,7 +16,9 @@ ) from 
data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.models.registry import ModelRegistry +from data_designer.engine.resources.managed_storage import ManagedBlobStorage from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType +from data_designer.engine.resources.seed_dataset_source import SeedDatasetRepository def test_configurable_task_metadata_creation(): @@ -68,7 +70,12 @@ def _initialize(self) -> None: mock_artifact_storage.final_dataset_folder_name = "final_dataset" mock_artifact_storage.partial_results_folder_name = "partial_results" mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage) + resource_provider = ResourceProvider( + artifact_storage=mock_artifact_storage, + model_registry=Mock(spec=ModelRegistry), + seed_dataset_repository=Mock(spec=SeedDatasetRepository), + blob_storage=Mock(spec=ManagedBlobStorage), + ) task = TestTask(config=config, resource_provider=resource_provider) @@ -99,7 +106,12 @@ def _validate(self) -> None: mock_artifact_storage.final_dataset_folder_name = "final_dataset" mock_artifact_storage.partial_results_folder_name = "partial_results" mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage) + resource_provider = ResourceProvider( + artifact_storage=mock_artifact_storage, + model_registry=Mock(spec=ModelRegistry), + seed_dataset_repository=Mock(spec=SeedDatasetRepository), + blob_storage=Mock(spec=ManagedBlobStorage), + ) task = TestTask(config=config, resource_provider=resource_provider) assert task._config.value == "test" @@ -138,7 +150,12 @@ def _initialize(self) -> None: mock_artifact_storage.partial_results_folder_name = "partial_results" mock_artifact_storage.dropped_columns_folder_name = "dropped_columns" mock_model_registry = Mock(spec=ModelRegistry) - resource_provider = ResourceProvider(artifact_storage=mock_artifact_storage, model_registry=mock_model_registry) + resource_provider = ResourceProvider( + artifact_storage=mock_artifact_storage, + model_registry=mock_model_registry, + seed_dataset_repository=Mock(spec=SeedDatasetRepository), + blob_storage=Mock(spec=ManagedBlobStorage), + ) task = TestTask(config=config, resource_provider=resource_provider) assert task._resource_provider == resource_provider diff --git a/tests/essentials/test_init.py b/tests/essentials/test_init.py index 89f8388a..9226da48 100644 --- a/tests/essentials/test_init.py +++ b/tests/essentials/test_init.py @@ -20,8 +20,6 @@ DataDesignerColumnType, DataDesignerConfig, DataDesignerConfigBuilder, - DatastoreSeedDatasetReference, - DatastoreSettings, DatetimeSamplerParams, ExpressionColumnConfig, GaussianSamplerParams, @@ -81,7 +79,6 @@ def test_config_imports(): """Test config-related imports""" assert DataDesignerConfig is not None assert DataDesignerConfigBuilder is not None - assert DatastoreSettings is not None assert isinstance(can_run_data_designer_locally(), bool) @@ -145,7 +142,6 @@ def test_sampler_params_imports(): def test_seed_config_imports(): """Test seed configuration imports""" - assert DatastoreSeedDatasetReference is not None assert SamplingStrategy is not None assert SeedConfig is not None @@ -217,7 +213,6 @@ def test_all_contains_config_classes(): """Test __all__ contains config classes""" assert "DataDesignerConfig" in __all__ assert "DataDesignerConfigBuilder" in __all__ - assert 
"DatastoreSettings" in __all__ def test_all_contains_column_configs(): @@ -275,7 +270,6 @@ def test_all_contains_model_configs(): def test_all_contains_seed_configs(): """Test __all__ contains seed configuration classes""" - assert "DatastoreSeedDatasetReference" in __all__ assert "SamplingStrategy" in __all__ assert "SeedConfig" in __all__