1515 LOCALES_WITH_MANAGED_DATASETS ,
1616 MAX_AGE ,
1717 MIN_AGE ,
18- US_STATES_AND_MAJOR_TERRITORIES ,
1918)
2019
2120
@@ -27,6 +26,7 @@ class SamplerType(str, Enum):
2726 DATETIME = "datetime"
2827 GAUSSIAN = "gaussian"
2928 PERSON = "person"
29+ PERSON_FROM_FAKER = "person_from_faker"
3030 POISSON = "poisson"
3131 SCIPY = "scipy"
3232 SUBCATEGORY = "subcategory"
@@ -219,8 +219,10 @@ class PersonSamplerParams(ConfigBase):
219219 locale : str = Field (
220220 default = "en_US" ,
221221 description = (
222- "Locale string, determines the language and geographic locale "
223- "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
222+ "Locale that determines the language and geographic location "
223+ "that a synthetic person will be sampled from. Must be a locale supported by "
224+ "a managed Nemotron Personas dataset. Managed datasets exist for the following locales: "
225+ f"{ ', ' .join (LOCALES_WITH_MANAGED_DATASETS )} ."
224226 ),
225227 )
226228 sex : Optional [SexT ] = Field (
@@ -237,36 +239,96 @@ class PersonSamplerParams(ConfigBase):
237239 min_length = 2 ,
238240 max_length = 2 ,
239241 )
240-
241- state : Optional [Union [str , list [str ]]] = Field (
242+ select_field_values : Optional [dict [str , list [str ]]] = Field (
242243 default = None ,
243244 description = (
244- "Only supported for 'en_US' locale. If specified, then only synthetic people "
245- "from these states will be sampled. States must be given as two-letter abbreviations."
245+ "Sample synthetic people with the specified field values. This is meant to be a flexible argument for "
246+ "selecting a subset of the population from the managed dataset. Note that this sampler does not support "
247+ "rare combinations of field values and will likely fail if your desired subset is not well-represented "
248+ "in the managed Nemotron Personas dataset. We generally recommend using the `sex`, `city`, and `age_range` "
249+ "arguments to filter the population when possible."
246250 ),
251+ examples = [
252+ {"state" : ["NY" , "CA" , "OH" , "TX" , "NV" ], "education_level" : ["high_school" , "some_college" , "bachelors" ]}
253+ ],
247254 )
248255
249256 with_synthetic_personas : bool = Field (
250257 default = False ,
251258 description = "If True, then append synthetic persona columns to each generated person." ,
252259 )
253260
254- sample_dataset_when_available : bool = Field (
255- default = True ,
256- description = "If True, sample person data from managed dataset when available. Otherwise, use Faker." ,
261+ @property
262+ def generator_kwargs (self ) -> list [str ]:
263+ """Keyword arguments to pass to the person generator."""
264+ return [f for f in list (PersonSamplerParams .model_fields ) if f != "locale" ]
265+
266+ @property
267+ def people_gen_key (self ) -> str :
268+ return f"{ self .locale } _with_personas" if self .with_synthetic_personas else self .locale
269+
270+ @field_validator ("age_range" )
271+ @classmethod
272+ def _validate_age_range (cls , value : list [int ]) -> list [int ]:
273+ msg_prefix = "'age_range' must be a list of two integers, representing the min and max age."
274+ if value [0 ] < MIN_AGE :
275+ raise ValueError (
276+ f"{ msg_prefix } The first integer (min age) must be greater than or equal to { MIN_AGE } , "
277+ f"but the first integer provided was { value [0 ]} ."
278+ )
279+ if value [1 ] > MAX_AGE :
280+ raise ValueError (
281+ f"{ msg_prefix } The second integer (max age) must be less than or equal to { MAX_AGE } , "
282+ f"but the second integer provided was { value [1 ]} ."
283+ )
284+ if value [0 ] >= value [1 ]:
285+ raise ValueError (
286+ f"{ msg_prefix } The first integer (min age) must be less than the second integer (max age), "
287+ f"but the first integer provided was { value [0 ]} and the second integer provided was { value [1 ]} ."
288+ )
289+ return value
290+
291+ @model_validator (mode = "after" )
292+ def _validate_locale_with_managed_datasets (self ) -> Self :
293+ if self .locale not in LOCALES_WITH_MANAGED_DATASETS :
294+ raise ValueError (
295+ "Person sampling from managed datasets is only supported for the following "
296+ f"locales: { ', ' .join (LOCALES_WITH_MANAGED_DATASETS )} ."
297+ )
298+ return self
299+
300+
301+ class PersonFromFakerSamplerParams (ConfigBase ):
302+ locale : str = Field (
303+ default = "en_US" ,
304+ description = (
305+ "Locale string, determines the language and geographic locale "
306+ "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
307+ ),
308+ )
309+ sex : Optional [SexT ] = Field (
310+ default = None ,
311+ description = "If specified, then only synthetic people of the specified sex will be sampled." ,
312+ )
313+ city : Optional [Union [str , list [str ]]] = Field (
314+ default = None ,
315+ description = "If specified, then only synthetic people from these cities will be sampled." ,
316+ )
317+ age_range : list [int ] = Field (
318+ default = DEFAULT_AGE_RANGE ,
319+ description = "If specified, then only synthetic people within this age range will be sampled." ,
320+ min_length = 2 ,
321+ max_length = 2 ,
257322 )
258323
259324 @property
260325 def generator_kwargs (self ) -> list [str ]:
261326 """Keyword arguments to pass to the person generator."""
262- return [f for f in list (PersonSamplerParams .model_fields ) if f != "locale" ]
327+ return [f for f in list (PersonFromFakerSamplerParams .model_fields ) if f != "locale" ]
263328
264329 @property
265330 def people_gen_key (self ) -> str :
266- if self .locale in LOCALES_WITH_MANAGED_DATASETS and self .sample_dataset_when_available :
267- return f"{ self .locale } _with_personas" if self .with_synthetic_personas else self .locale
268- else :
269- return f"{ self .locale } _faker"
331+ return f"{ self .locale } _faker"
270332
271333 @field_validator ("age_range" )
272334 @classmethod
@@ -298,35 +360,13 @@ def _validate_locale(cls, value: str) -> str:
298360 )
299361 return value
300362
301- @model_validator (mode = "after" )
302- def _validate_state (self ) -> Self :
303- if self .state is not None :
304- orig_state_value = self .state
305- if self .locale != "en_US" :
306- raise ValueError ("'state' is only supported for 'en_US' locale." )
307- if not isinstance (self .state , list ):
308- self .state = [self .state ]
309- self .state = [state .upper () for state in self .state ]
310- for state in self .state :
311- if state not in US_STATES_AND_MAJOR_TERRITORIES :
312- raise ValueError (f"State { orig_state_value !r} is not a supported state." )
313- return self
314-
315- @model_validator (mode = "after" )
316- def _validate_with_synthetic_personas (self ) -> Self :
317- if self .with_synthetic_personas and self .locale not in LOCALES_WITH_MANAGED_DATASETS :
318- raise ValueError (
319- "'with_synthetic_personas' is only supported for the following "
320- f"locales: { ', ' .join (LOCALES_WITH_MANAGED_DATASETS )} ."
321- )
322- return self
323-
324363
325364SamplerParamsT : TypeAlias = Union [
326365 SubcategorySamplerParams ,
327366 CategorySamplerParams ,
328367 DatetimeSamplerParams ,
329368 PersonSamplerParams ,
369+ PersonFromFakerSamplerParams ,
330370 TimeDeltaSamplerParams ,
331371 UUIDSamplerParams ,
332372 BernoulliSamplerParams ,
0 commit comments