
Commit 7cba674

chore: split person samplers and use parameters in sql exec (#48)
* split person samplers
* pass catalogs
* remove locale validation on dataset-based person sampler
* use parameters in sql execution
* not using that
* update tests
* add temp assets path
* update error when too few samples exist
* update error message name
* add expected params to test
1 parent cb0b1c6 commit 7cba674

33 files changed (+310, -517 lines)

pyproject.toml

Lines changed: 0 additions & 2 deletions
@@ -34,7 +34,6 @@ dependencies = [
     "rich>=13.7.1",
     "typer>=0.12.0",
     "anyascii>=0.3.3,<1.0",
-    "boto3==1.35.74",
     "datasets>=4.0.0",
     "duckdb==1.1.3",
     "faker==20.1.0",
@@ -48,7 +47,6 @@ dependencies = [
     "networkx==3.0",
     "pydantic[email]>=2.9.2",
     "scipy>=1.11.0",
-    "smart-open==7.0.5",
     "sqlfluff==3.2.0",
     "tiktoken>=0.8.0",
     "ruff==0.12.3",

src/data_designer/cli/commands/list.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
 from data_designer.cli.repositories.model_repository import ModelRepository
 from data_designer.cli.repositories.provider_repository import ProviderRepository
 from data_designer.cli.ui import console, print_error, print_header, print_info, print_warning
-from data_designer.config.utils.constants import DATA_DESIGNER_HOME_DIR, NordColor
+from data_designer.config.utils.constants import DATA_DESIGNER_HOME, NordColor
 
 
 def list_command() -> None:
@@ -17,12 +17,12 @@ def list_command() -> None:
     """
     # Determine config directory
     print_header("Data Designer Configurations")
-    print_info(f"Configuration directory: {DATA_DESIGNER_HOME_DIR}")
+    print_info(f"Configuration directory: {DATA_DESIGNER_HOME}")
     console.print()
 
     # Display providers
-    display_providers(ProviderRepository(DATA_DESIGNER_HOME_DIR))
-    display_models(ModelRepository(DATA_DESIGNER_HOME_DIR))
+    display_providers(ProviderRepository(DATA_DESIGNER_HOME))
+    display_models(ModelRepository(DATA_DESIGNER_HOME))
 
 
 def display_providers(provider_repo: ProviderRepository) -> None:

src/data_designer/cli/commands/models.py

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from data_designer.cli.controllers.model_controller import ModelController
-from data_designer.config.utils.constants import DATA_DESIGNER_HOME_DIR
+from data_designer.config.utils.constants import DATA_DESIGNER_HOME
 
 
 def models_command() -> None:
-    controller = ModelController(DATA_DESIGNER_HOME_DIR)
+    controller = ModelController(DATA_DESIGNER_HOME)
     controller.run()

src/data_designer/cli/commands/providers.py

Lines changed: 2 additions & 2 deletions
@@ -2,10 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from data_designer.cli.controllers.provider_controller import ProviderController
-from data_designer.config.utils.constants import DATA_DESIGNER_HOME_DIR
+from data_designer.config.utils.constants import DATA_DESIGNER_HOME
 
 
 def providers_command() -> None:
     """Configure model providers interactively."""
-    controller = ProviderController(DATA_DESIGNER_HOME_DIR)
+    controller = ProviderController(DATA_DESIGNER_HOME)
     controller.run()

src/data_designer/cli/commands/reset.py

Lines changed: 4 additions & 4 deletions
@@ -14,20 +14,20 @@
     print_success,
     print_text,
 )
-from data_designer.config.utils.constants import DATA_DESIGNER_HOME_DIR
+from data_designer.config.utils.constants import DATA_DESIGNER_HOME
 
 
 def reset_command() -> None:
     """Reset configuration files by deleting them after confirmation."""
     print_header("Reset Configuration")
 
     # Determine configuration directory
-    print_info(f"Configuration directory: {DATA_DESIGNER_HOME_DIR}")
+    print_info(f"Configuration directory: {DATA_DESIGNER_HOME}")
     console.print()
 
     # Create repositories
-    provider_repo = ProviderRepository(DATA_DESIGNER_HOME_DIR)
-    model_repo = ModelRepository(DATA_DESIGNER_HOME_DIR)
+    provider_repo = ProviderRepository(DATA_DESIGNER_HOME)
+    model_repo = ModelRepository(DATA_DESIGNER_HOME)
 
     # Check which config files exist
     provider_exists = provider_repo.exists()

src/data_designer/config/sampler_params.py

Lines changed: 78 additions & 38 deletions
@@ -15,7 +15,6 @@
     LOCALES_WITH_MANAGED_DATASETS,
     MAX_AGE,
     MIN_AGE,
-    US_STATES_AND_MAJOR_TERRITORIES,
 )
 
 
@@ -27,6 +26,7 @@ class SamplerType(str, Enum):
     DATETIME = "datetime"
     GAUSSIAN = "gaussian"
     PERSON = "person"
+    PERSON_FROM_FAKER = "person_from_faker"
     POISSON = "poisson"
     SCIPY = "scipy"
     SUBCATEGORY = "subcategory"
@@ -219,8 +219,10 @@ class PersonSamplerParams(ConfigBase):
     locale: str = Field(
         default="en_US",
         description=(
-            "Locale string, determines the language and geographic locale "
-            "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
+            "Locale that determines the language and geographic location "
+            "that a synthetic person will be sampled from. Must be a locale supported by "
+            "a managed Nemotron Personas dataset. Managed datasets exist for the following locales: "
+            f"{', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
         ),
     )
     sex: Optional[SexT] = Field(
@@ -237,36 +239,96 @@ class PersonSamplerParams(ConfigBase):
         min_length=2,
         max_length=2,
     )
-
-    state: Optional[Union[str, list[str]]] = Field(
+    select_field_values: Optional[dict[str, list[str]]] = Field(
         default=None,
         description=(
-            "Only supported for 'en_US' locale. If specified, then only synthetic people "
-            "from these states will be sampled. States must be given as two-letter abbreviations."
+            "Sample synthetic people with the specified field values. This is meant to be a flexible argument for "
+            "selecting a subset of the population from the managed dataset. Note that this sampler does not support "
+            "rare combinations of field values and will likely fail if your desired subset is not well-represented "
+            "in the managed Nemotron Personas dataset. We generally recommend using the `sex`, `city`, and `age_range` "
+            "arguments to filter the population when possible."
         ),
+        examples=[
+            {"state": ["NY", "CA", "OH", "TX", "NV"], "education_level": ["high_school", "some_college", "bachelors"]}
+        ],
     )
 
     with_synthetic_personas: bool = Field(
         default=False,
         description="If True, then append synthetic persona columns to each generated person.",
     )
 
-    sample_dataset_when_available: bool = Field(
-        default=True,
-        description="If True, sample person data from managed dataset when available. Otherwise, use Faker.",
+    @property
+    def generator_kwargs(self) -> list[str]:
+        """Keyword arguments to pass to the person generator."""
+        return [f for f in list(PersonSamplerParams.model_fields) if f != "locale"]
+
+    @property
+    def people_gen_key(self) -> str:
+        return f"{self.locale}_with_personas" if self.with_synthetic_personas else self.locale
+
+    @field_validator("age_range")
+    @classmethod
+    def _validate_age_range(cls, value: list[int]) -> list[int]:
+        msg_prefix = "'age_range' must be a list of two integers, representing the min and max age."
+        if value[0] < MIN_AGE:
+            raise ValueError(
+                f"{msg_prefix} The first integer (min age) must be greater than or equal to {MIN_AGE}, "
+                f"but the first integer provided was {value[0]}."
+            )
+        if value[1] > MAX_AGE:
+            raise ValueError(
+                f"{msg_prefix} The second integer (max age) must be less than or equal to {MAX_AGE}, "
+                f"but the second integer provided was {value[1]}."
+            )
+        if value[0] >= value[1]:
+            raise ValueError(
+                f"{msg_prefix} The first integer (min age) must be less than the second integer (max age), "
+                f"but the first integer provided was {value[0]} and the second integer provided was {value[1]}."
+            )
+        return value
+
+    @model_validator(mode="after")
+    def _validate_locale_with_managed_datasets(self) -> Self:
+        if self.locale not in LOCALES_WITH_MANAGED_DATASETS:
+            raise ValueError(
+                "Person sampling from managed datasets is only supported for the following "
+                f"locales: {', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
+            )
+        return self
+
+
+class PersonFromFakerSamplerParams(ConfigBase):
+    locale: str = Field(
+        default="en_US",
+        description=(
+            "Locale string, determines the language and geographic locale "
+            "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
+        ),
+    )
+    sex: Optional[SexT] = Field(
+        default=None,
+        description="If specified, then only synthetic people of the specified sex will be sampled.",
+    )
+    city: Optional[Union[str, list[str]]] = Field(
+        default=None,
+        description="If specified, then only synthetic people from these cities will be sampled.",
+    )
+    age_range: list[int] = Field(
+        default=DEFAULT_AGE_RANGE,
+        description="If specified, then only synthetic people within this age range will be sampled.",
+        min_length=2,
+        max_length=2,
     )
 
     @property
     def generator_kwargs(self) -> list[str]:
         """Keyword arguments to pass to the person generator."""
-        return [f for f in list(PersonSamplerParams.model_fields) if f != "locale"]
+        return [f for f in list(PersonFromFakerSamplerParams.model_fields) if f != "locale"]
 
     @property
     def people_gen_key(self) -> str:
-        if self.locale in LOCALES_WITH_MANAGED_DATASETS and self.sample_dataset_when_available:
-            return f"{self.locale}_with_personas" if self.with_synthetic_personas else self.locale
-        else:
-            return f"{self.locale}_faker"
+        return f"{self.locale}_faker"
 
     @field_validator("age_range")
     @classmethod
@@ -298,35 +360,13 @@ def _validate_locale(cls, value: str) -> str:
             )
         return value
 
-    @model_validator(mode="after")
-    def _validate_state(self) -> Self:
-        if self.state is not None:
-            orig_state_value = self.state
-            if self.locale != "en_US":
-                raise ValueError("'state' is only supported for 'en_US' locale.")
-            if not isinstance(self.state, list):
-                self.state = [self.state]
-            self.state = [state.upper() for state in self.state]
-            for state in self.state:
-                if state not in US_STATES_AND_MAJOR_TERRITORIES:
-                    raise ValueError(f"State {orig_state_value!r} is not a supported state.")
-        return self
-
-    @model_validator(mode="after")
-    def _validate_with_synthetic_personas(self) -> Self:
-        if self.with_synthetic_personas and self.locale not in LOCALES_WITH_MANAGED_DATASETS:
-            raise ValueError(
-                "'with_synthetic_personas' is only supported for the following "
-                f"locales: {', '.join(LOCALES_WITH_MANAGED_DATASETS)}."
-            )
-        return self
-
 
 SamplerParamsT: TypeAlias = Union[
     SubcategorySamplerParams,
     CategorySamplerParams,
     DatetimeSamplerParams,
     PersonSamplerParams,
+    PersonFromFakerSamplerParams,
     TimeDeltaSamplerParams,
     UUIDSamplerParams,
     BernoulliSamplerParams,
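
For orientation, here is a minimal usage sketch of the split sampler params. It assumes the data_designer package is installed and that both classes are importable from data_designer.config.sampler_params (matching the file path above); the field values are purely illustrative.

# Hedged usage sketch of the two person samplers introduced in this commit.
import pydantic

from data_designer.config.sampler_params import (
    PersonFromFakerSamplerParams,
    PersonSamplerParams,
)

# Dataset-backed sampler: the locale must be covered by a managed Nemotron Personas
# dataset, and select_field_values narrows the sampled subset of the population.
dataset_params = PersonSamplerParams(
    locale="en_US",
    age_range=[25, 65],
    select_field_values={"state": ["NY", "CA"]},
)
print(dataset_params.people_gen_key)  # "en_US" (or "en_US_with_personas")

# Faker-backed sampler: no managed-dataset locale restriction; key is suffixed with "_faker".
faker_params = PersonFromFakerSamplerParams(locale="fr_FR", age_range=[18, 80])
print(faker_params.people_gen_key)  # "fr_FR_faker"

# The shared age_range validator rejects out-of-order bounds.
try:
    PersonFromFakerSamplerParams(age_range=[70, 30])
except pydantic.ValidationError as err:
    print(err)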

src/data_designer/config/utils/constants.py

Lines changed: 8 additions & 4 deletions
@@ -260,17 +260,21 @@ class NordColor(Enum):
     "zu_ZA",
 ]
 
-DATA_DESIGNER_HOME_DIR_ENV_VAR = "DATA_DESIGNER_HOME_DIR"
+DATA_DESIGNER_HOME_ENV_VAR = "DATA_DESIGNER_HOME"
 
-DATA_DESIGNER_HOME_DIR = Path(os.getenv(DATA_DESIGNER_HOME_DIR_ENV_VAR, Path.home() / ".data-designer"))
+DATA_DESIGNER_HOME = Path(os.getenv(DATA_DESIGNER_HOME_ENV_VAR, Path.home() / ".data-designer"))
+
+MANAGED_ASSETS_PATH_ENV_VAR = "DATA_DESIGNER_MANAGED_ASSETS_PATH"
+
+MANAGED_ASSETS_PATH = Path(os.getenv(MANAGED_ASSETS_PATH_ENV_VAR, DATA_DESIGNER_HOME / "managed-assets"))
 
 MODEL_CONFIGS_FILE_NAME = "model_configs.yaml"
 
-MODEL_CONFIGS_FILE_PATH = DATA_DESIGNER_HOME_DIR / MODEL_CONFIGS_FILE_NAME
+MODEL_CONFIGS_FILE_PATH = DATA_DESIGNER_HOME / MODEL_CONFIGS_FILE_NAME
 
 MODEL_PROVIDERS_FILE_NAME = "model_providers.yaml"
 
-MODEL_PROVIDERS_FILE_PATH = DATA_DESIGNER_HOME_DIR / MODEL_PROVIDERS_FILE_NAME
+MODEL_PROVIDERS_FILE_PATH = DATA_DESIGNER_HOME / MODEL_PROVIDERS_FILE_NAME
 
 NVIDIA_PROVIDER_NAME = "nvidia"
 
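
Below is a self-contained sketch (standard library only) of the environment-variable override pattern these constants follow: the variable wins when set, otherwise the path falls back to a default rooted under the Data Designer home directory. The variable names are copied from the diff above; the printed paths depend on your environment.

import os
from pathlib import Path

DATA_DESIGNER_HOME_ENV_VAR = "DATA_DESIGNER_HOME"
DATA_DESIGNER_HOME = Path(os.getenv(DATA_DESIGNER_HOME_ENV_VAR, Path.home() / ".data-designer"))

MANAGED_ASSETS_PATH_ENV_VAR = "DATA_DESIGNER_MANAGED_ASSETS_PATH"
MANAGED_ASSETS_PATH = Path(os.getenv(MANAGED_ASSETS_PATH_ENV_VAR, DATA_DESIGNER_HOME / "managed-assets"))

if __name__ == "__main__":
    # e.g. run with: DATA_DESIGNER_MANAGED_ASSETS_PATH=/tmp/assets python sketch.py
    print(f"home:   {DATA_DESIGNER_HOME}")
    print(f"assets: {MANAGED_ASSETS_PATH}")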

src/data_designer/engine/resources/managed_dataset_generator.py

Lines changed: 4 additions & 6 deletions
@@ -17,21 +17,19 @@ def generate_samples(
         self,
         size: int = 1,
         evidence: dict[str, Any | list[Any]] = {},
-        seed: int | None = None,
     ) -> pd.DataFrame:
+        parameters = []
         query = f"select * from {self.dataset_name}"
-        # Build the WHERE clause if there are filters
-        # NOTE: seed is not used because it's not straightforward
-        # to make randomization both fast and repeatable
         if evidence:
             where_conditions = []
             for column, values in evidence.items():
                 if values:
                     values = values if isinstance(values, list) else [values]
-                    formatted_values = [f"'{val}'" for val in values]
+                    formatted_values = ["?"] * len(values)
                     condition = f"{column} IN ({', '.join(formatted_values)})"
                     where_conditions.append(condition)
+                    parameters.extend(values)
             if where_conditions:
                 query += " where " + " and ".join(where_conditions)
         query += f" order by random() limit {size}"
-        return self.managed_datasets.query(query)
+        return self.managed_datasets.query(query, parameters)
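
The sketch below shows the same parameterized-query pattern the generator now uses in isolation: one "?" placeholder per filter value, with the values passed to DuckDB separately instead of being quoted into the SQL string. The "people" table and its columns are made up for illustration; only column names and the limit are interpolated, never values.

import duckdb
import pandas as pd

con = duckdb.connect()
con.execute(
    "CREATE TABLE people AS "
    "SELECT * FROM (VALUES ('Ada', 'NY', 'Female'), ('Bob', 'CA', 'Male'), ('Cyd', 'TX', 'Female')) "
    "AS t(name, state, sex)"
)


def sample_people(evidence: dict[str, list[str]], size: int = 2) -> pd.DataFrame:
    parameters: list[str] = []
    query = "select * from people"
    where_conditions = []
    for column, values in evidence.items():
        if values:
            placeholders = ", ".join(["?"] * len(values))  # one placeholder per value
            where_conditions.append(f"{column} IN ({placeholders})")
            parameters.extend(values)
    if where_conditions:
        query += " where " + " and ".join(where_conditions)
    query += f" order by random() limit {size}"
    return con.execute(query, parameters).df()


print(sample_people({"state": ["NY", "CA"], "sex": ["Female"]}))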

src/data_designer/engine/resources/managed_dataset_repository.py

Lines changed: 8 additions & 6 deletions
@@ -9,6 +9,7 @@
 import tempfile
 import threading
 import time
+from typing import Any
 
 import duckdb
 import pandas as pd
@@ -60,7 +61,7 @@ def name(self) -> str:
 
 class ManagedDatasetRepository(ABC):
     @abstractmethod
-    def query(self, sql: str) -> pd.DataFrame: ...
+    def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame: ...
 
     @property
     @abstractmethod
@@ -129,7 +130,7 @@ def _register_datasets(self):
         for table in self.data_catalog:
             key = table.source if table.schema == "main" else f"{table.schema}/{table.source}"
             if self._use_cache:
-                tmp_root = Path(tempfile.gettempdir()) / "gretel_ds_cache"
+                tmp_root = Path(tempfile.gettempdir()) / "dd_cache"
                 local_path = tmp_root / key
                 local_path.parent.mkdir(parents=True, exist_ok=True)
                 if not local_path.exists():
@@ -160,7 +161,7 @@ def _register_datasets(self):
         # Signal that registration is complete so any waiting queries can proceed.
         self._registration_event.set()
 
-    def query(self, sql: str) -> pd.DataFrame:
+    def query(self, sql: str, parameters: list[Any]) -> pd.DataFrame:
         # Ensure dataset registration has completed. Possible future optimization:
         # pull datasets in parallel and only wait here if the query requires a
         # table that isn't cached.
@@ -173,7 +174,7 @@ def query(self, sql: str) -> pd.DataFrame:
         # more details here: https://duckdb.org/docs/stable/guides/python/multiple_threads.html
         cursor = self.db.cursor()
         try:
-            df = cursor.sql(sql).df()
+            df = cursor.execute(sql, parameters).df()
         finally:
             cursor.close()
         return df
@@ -183,10 +184,11 @@ def data_catalog(self) -> DataCatalog:
         return self._data_catalog
 
 
-def load_managed_dataset_repository(blob_storage: ManagedBlobStorage) -> ManagedDatasetRepository:
+def load_managed_dataset_repository(blob_storage: ManagedBlobStorage, locales: list[str]) -> ManagedDatasetRepository:
     return DuckDBDatasetRepository(
         blob_storage,
-        {"threads": 1, "memory_limit": "2 gb"},
+        config={"threads": 1, "memory_limit": "2 gb"},
+        data_catalog=[Table(f"{locale}.parquet") for locale in locales],
         # Only cache if not using local storage.
         use_cache=not isinstance(blob_storage, LocalBlobStorageProvider),
     )
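
For context, a small sketch of the cursor-per-thread pattern the repository relies on (see the DuckDB multi-threading guide linked in the diff), combined with parameterized execution: one shared connection, a fresh cursor per querying thread, and values passed to execute() rather than spliced into the SQL. The "en_US" table here is a stand-in for a registered managed dataset.

import threading

import duckdb
import pandas as pd

db = duckdb.connect()
db.execute("CREATE TABLE en_US AS SELECT * FROM (VALUES ('Ada', 34), ('Bob', 58)) AS t(name, age)")


def query(sql: str, parameters: list) -> pd.DataFrame:
    cursor = db.cursor()  # each thread gets its own cursor; the parent connection is shared
    try:
        df = cursor.execute(sql, parameters).df()
    finally:
        cursor.close()
    return df


def worker(min_age: int) -> None:
    df = query("select * from en_US where age >= ? order by random() limit 1", [min_age])
    print(df.to_dict(orient="records"))


threads = [threading.Thread(target=worker, args=(age,)) for age in (30, 50)]
for t in threads:
    t.start()
for t in threads:
    t.join()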
