Skip to content

Commit bff760f

Browse files
committed
ignore if not needed
1 parent 8c8d543 commit bff760f

File tree

5 files changed

+47
-12
lines changed

5 files changed

+47
-12
lines changed

src/data_designer/engine/column_generators/generators/samplers.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from functools import partial
55
import logging
6+
from pathlib import Path
67
import random
78
from typing import Callable
89

@@ -16,8 +17,8 @@
1617
)
1718
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
1819
from data_designer.engine.processing.utils import concat_datasets
20+
from data_designer.engine.resources.errors import ManagedAssetMissingError, ManagedAssetsPathNotSetError
1921
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
20-
from data_designer.engine.resources.resource_provider import ResourceType
2122
from data_designer.engine.sampling_gen.data_sources.sources import SamplerType
2223
from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
2324
from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator
@@ -32,7 +33,7 @@ def metadata() -> GeneratorMetadata:
3233
name="sampler_column_generator",
3334
description="Generate columns using sampling-based method.",
3435
generation_strategy=GenerationStrategy.FULL_COLUMN,
35-
required_resources=[ResourceType.BLOB_STORAGE],
36+
required_resources=None,
3637
)
3738

3839
def generate(self, data: pd.DataFrame) -> pd.DataFrame:
@@ -52,7 +53,32 @@ def _needs_person_generator(self) -> bool:
5253
def _person_generator_loader(self) -> Callable[[bool], ManagedDatasetGenerator]:
5354
return partial(load_person_data_sampler, blob_storage=self.resource_provider.blob_storage)
5455

56+
def _check_managed_assets_exist_if_needed(self) -> None:
57+
if self._needs_person_generator:
58+
if (
59+
self.resource_provider.blob_storage is None
60+
or not self.resource_provider.blob_storage.root_path.exists()
61+
):
62+
raise ManagedAssetsPathNotSetError(
63+
"🛑 The managed assets path does not exist. If you are using the Person Sampler, "
64+
"You must have a managed assets directory that contains the Nemotron-Personas dataset "
65+
"for each locale you want to sample from."
66+
)
67+
is_missing = []
68+
for c in [c for c in self.config.columns if c.sampler_type == SamplerType.PERSON]:
69+
locale_file_path = self.resource_provider.blob_storage.root_path / f"datasets/{c.params.locale}.parquet"
70+
if not Path(locale_file_path).exists() or not Path(locale_file_path).is_file():
71+
is_missing.append([c.params.locale, locale_file_path])
72+
if len(is_missing) > 0:
73+
raise ManagedAssetMissingError(
74+
"🛑 The Nemotron-Personas dataset is missing for the following locales: "
75+
f"{', '.join([f'{locale}' for locale, _ in is_missing])}. "
76+
"Please ensure the files exist at the following paths: "
77+
f"{', '.join([f'{str(file_path)!r}' for _, file_path in is_missing])}"
78+
)
79+
5580
def _create_sampling_dataset_generator(self) -> SamplingDatasetGenerator:
81+
self._check_managed_assets_exist_if_needed()
5682
return SamplingDatasetGenerator(
5783
sampler_columns=self.config,
5884
person_generator_loader=(self._person_generator_loader if self._needs_person_generator else None),
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from data_designer.engine.errors import DataDesignerError
2+
3+
4+
class ResourceError(DataDesignerError):
5+
"""Base exception for all errors related to resources."""
6+
7+
8+
class ManagedAssetsPathNotSetError(ResourceError):
9+
"""Exception for all errors related to the managed assets path not being set."""
10+
11+
12+
class ManagedAssetMissingError(ResourceError):
13+
"""Exception for all errors related to the managed assets missing."""

src/data_designer/engine/resources/managed_storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ class LocalBlobStorageProvider(ManagedBlobStorage):
4343
"""
4444

4545
def __init__(self, root_path: Path) -> None:
46-
self._root_path = root_path
46+
self.root_path = root_path
4747

4848
@contextmanager
4949
def get_blob(self, blob_key: str) -> Iterator[IO]:
5050
with open(self._key_uri_builder(blob_key), "rb") as fd:
5151
yield fd
5252

5353
def _key_uri_builder(self, key: str) -> str:
54-
return f"{self._root_path}/{key}"
54+
return f"{self.root_path}/{key}"
5555

5656

5757
def init_managed_blob_storage(assets_storage: str) -> ManagedBlobStorage:

src/data_designer/interface/data_designer.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,9 @@ def _create_resource_provider(
339339
model_provider_registry=self._model_provider_registry,
340340
secret_resolver=self._secret_resolver,
341341
),
342-
blob_storage=init_managed_blob_storage(str(self._resolve_managed_assets_path())),
342+
blob_storage=None
343+
if not self._managed_assets_path.exists()
344+
else init_managed_blob_storage(str(self._managed_assets_path)),
343345
datastore=(
344346
LocalSeedDatasetDataStore()
345347
if (settings := config_builder.get_seed_datastore_settings()) is None
@@ -349,9 +351,3 @@ def _create_resource_provider(
349351
)
350352
),
351353
)
352-
353-
def _resolve_managed_assets_path(self) -> Path:
354-
if not self._managed_assets_path.exists():
355-
logger.info(f"🏗️ Creating your managed assets path at: {str(self._managed_assets_path)!r}.")
356-
self._managed_assets_path.mkdir(parents=True, exist_ok=True)
357-
return self._managed_assets_path

tests/engine/resources/test_managed_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_uri_for_key_normalization(stub_concrete_storage, test_key, expected_uri
4545
)
4646
def test_local_blob_storage_provider_init(test_case, root_path):
4747
provider = LocalBlobStorageProvider(root_path)
48-
assert provider._root_path == root_path
48+
assert provider.root_path == root_path
4949

5050

5151
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)