Skip to content

Commit b15e2ba

Browse files
committed
scaffold
1 parent fdbc012 commit b15e2ba

File tree

7 files changed

+319
-2
lines changed

7 files changed

+319
-2
lines changed

examples/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
artifacts

examples/example.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
from data_designer.essentials import (
2+
DataDesignerConfigBuilder,
3+
ModelConfig,
4+
InferenceParameters,
5+
SamplerColumnConfig,
6+
CategorySamplerParams,
7+
SubcategorySamplerParams,
8+
PersonSamplerParams,
9+
LLMTextColumnConfig,
10+
Score,
11+
DataDesigner,
12+
ToJsonlProcessorConfig,
13+
)
14+
15+
# define model aliases
16+
model_alias_generator = "content_generator"
17+
model_configs = [
18+
ModelConfig(
19+
alias=model_alias_generator,
20+
provider="nvidia",
21+
model="deepseek-ai/deepseek-r1-distill-qwen-14b",
22+
inference_parameters=InferenceParameters(
23+
max_tokens=8000,
24+
temperature=0.7,
25+
top_p=0.95,
26+
),
27+
)
28+
]
29+
30+
config_builder = DataDesignerConfigBuilder(model_configs=model_configs)
31+
32+
# ESI levels
33+
ESI_LEVELS = [
34+
"ESI 1: Resuscitation",
35+
"ESI 2: Emergency",
36+
"ESI 3: Urgent",
37+
"ESI 4: Less Urgent",
38+
"ESI 5: Non-urgent",
39+
]
40+
41+
# Unique record ID
42+
config_builder.add_column(
43+
name="record_id",
44+
column_type="sampler",
45+
sampler_type="uuid",
46+
params={"short_form": True, "uppercase": True}
47+
)
48+
49+
# ESI level (balanced sampling)
50+
config_builder.add_column(
51+
SamplerColumnConfig(
52+
name="esi_level_description",
53+
sampler_type="category",
54+
params=CategorySamplerParams(
55+
values=ESI_LEVELS,
56+
),
57+
)
58+
)
59+
60+
# Clinical scenario (conditioned on ESI level)
61+
config_builder.add_column(
62+
SamplerColumnConfig(
63+
name="clinical_scenario",
64+
sampler_type="subcategory",
65+
params=SubcategorySamplerParams(
66+
category="esi_level_description",
67+
values={
68+
ESI_LEVELS[0]: [
69+
"Cardiac arrest",
70+
"Unresponsive with no pulse",
71+
"Severe respiratory distress",
72+
"Major trauma with signs of shock",
73+
"Suspected narcotic overdose with shallow respirations",
74+
],
75+
ESI_LEVELS[1]: [
76+
"Crushing substernal chest pain radiating to the left arm",
77+
"Sudden onset of facial droop and arm weakness",
78+
"New onset confusion in an elderly patient",
79+
"Active suicidal ideation with a plan",
80+
"High-speed motor vehicle accident",
81+
"Severe abdominal pain in a patient with a history of aortic aneurysm",
82+
],
83+
ESI_LEVELS[2]: [
84+
"Abdominal pain with fever and nausea",
85+
"High fever with a productive cough and history of COPD",
86+
"Displaced fracture with visible deformity",
87+
"Asthma attack, responsive to initial treatment",
88+
"Vaginal bleeding in a pregnant patient",
89+
"Head injury with brief loss of consciousness",
90+
],
91+
ESI_LEVELS[3]: [
92+
"Simple laceration requiring sutures",
93+
"Twisted ankle, unable to bear weight",
94+
"Sore throat with fever",
95+
"Symptoms of a urinary tract infection",
96+
"Painful ear with fever in a child",
97+
],
98+
ESI_LEVELS[4]: [
99+
"Request for a prescription refill",
100+
"Suture removal",
101+
"Minor rash present for several days",
102+
"Common cold symptoms",
103+
"Follow-up for a minor wound check",
104+
],
105+
},
106+
),
107+
)
108+
)
109+
110+
# Synthetic patient info
111+
config_builder.add_column(
112+
SamplerColumnConfig(
113+
name="patient",
114+
sampler_type="person",
115+
params=PersonSamplerParams(age_range=[18, 70]),
116+
)
117+
)
118+
119+
# Triage note writing style (captures range from poor to best quality notes)
120+
config_builder.add_column(
121+
SamplerColumnConfig(
122+
name="writing_style",
123+
sampler_type="category",
124+
params=CategorySamplerParams(
125+
values=["Draft", "Adequate", "Polished"]
126+
),
127+
)
128+
)
129+
130+
# LLM-generated triage note
131+
config_builder.add_column(
132+
LLMTextColumnConfig(
133+
name="content",
134+
prompt=(
135+
"You are an experienced triage nurse in a busy Emergency Department writing a draft note. "
136+
"Write a realistic, concise triage note in a telegraphic style using common medical abbreviations. "
137+
"The note is for a {{ patient.age }} y/o {{ 'M' if patient.sex == 'Male' else 'F' }}. "
138+
"Triage classification: '{{ esi_level_description }}'. "
139+
"Reason for visit: '{{ clinical_scenario }}'. "
140+
"Desired writing style: '{{ writing_style }}'. "
141+
"Structure the note with 'CC:' and 'HPI:'. "
142+
"Adjust the style and level of clinical detail based on the 'writing_style': "
143+
"- Draft: Use minimal structure, brief statements, and omit some details; clinical indicators may be less clear. "
144+
"- Adequate: Use complete sentences, include all relevant clinical indicators, but avoid excessive detail. "
145+
"- Polished: Be thorough, precise, and clear; include nuanced or subtle signs and show strong clinical reasoning. "
146+
"Also, adjust level of detail based on urgency (ESI 1 is always brief). "
147+
"Respond with ONLY the note text, starting with 'CC:'."
148+
),
149+
model_alias=model_alias_generator,
150+
)
151+
)
152+
153+
# Rubric: clinical coherence
154+
clinical_coherence_rubric = Score(
155+
name="Clinical Coherence",
156+
description="Evaluates how well the clinical details in the triage note align with the assigned ESI level and scenario.",
157+
options={
158+
"5": "Note is perfectly aligned with the ESI level and scenario; details are clinically plausible and specific.",
159+
"4": "Note is well-aligned, with only minor details that might be slightly inconsistent.",
160+
"3": "Note is generally consistent, but some key clinical indicators are missing or don't fully match the ESI level.",
161+
"2": "Note shows significant inconsistency between the clinical details and the assigned ESI level.",
162+
"1": "Note is clinically incoherent and does not reflect the assigned ESI level or scenario at all."
163+
}
164+
)
165+
166+
# Rubric: ESI level complexity (reduced to 3 levels: Simple, Moderate, Complex)
167+
esi_level_complexity_rubric = Score(
168+
name="ESI Level Complexity",
169+
description="Evaluates how difficult it is to infer the correct ESI level from the note. Higher scores indicate greater complexity, which is desirable for creating a challenging dataset.",
170+
options={
171+
"Complex": "Note contains subtle or conflicting information, requiring clinical reasoning to distinguish between ESI levels.",
172+
"Moderate": "Note requires some clinical inference; indicators are present but not always immediately obvious.",
173+
"Simple": "Note uses clear, direct, or textbook indicators that make the ESI level obvious."
174+
}
175+
)
176+
177+
jsonl_entry_template = {
178+
"messages": [
179+
{
180+
"role": "system",
181+
"content": (
182+
"You are an expert ER triage nurse. Your task is to classify the following triage note into one of the five Emergency Severity Index (ESI) levels."
183+
f" The possible levels are: {', '.join([repr(level) for level in ESI_LEVELS])}."
184+
" Carefully analyze the clinical details in the triage note, focusing on patient acuity, resource needs, and risk of rapid deterioration."
185+
" Respond with only the selected ESI level description, exactly matching one of the listed possibilities. Do not provide extra text or explanation."
186+
)
187+
},
188+
{
189+
"role": "user",
190+
"content": (
191+
"Triage Note: {{ content }}\n"
192+
"Classify the ESI level for this note based on the provided definitions."
193+
" Respond in JSON format only: { \"esi_level_description\": \"...\" }"
194+
)
195+
},
196+
{
197+
"role": "assistant",
198+
"content": (
199+
'{ "esi_level_description": "{{ esi_level_description }}" }'
200+
)
201+
},
202+
],
203+
}
204+
205+
config_builder.add_processor(
206+
ToJsonlProcessorConfig(
207+
template=jsonl_entry_template,
208+
folder_name="jsonl_files",
209+
fraction_per_file={
210+
"train.jsonl": 0.8,
211+
"validation.jsonl": 0.2,
212+
},
213+
)
214+
)
215+
216+
dd = DataDesigner(artifact_path="./artifacts", blob_storage_path="/Users/amanoel/Data/nemotron-personas-datasets_v0.0.6")
217+
preview = dd.preview(config_builder, num_records=10)

src/data_designer/config/processors.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515

1616
class ProcessorType(str, Enum):
1717
DROP_COLUMNS = "drop_columns"
18+
TO_JSONL = "to_jsonl"
1819

1920

2021
class ProcessorConfig(ConfigBase, ABC):
2122
build_stage: BuildStage = Field(
22-
..., description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}"
23+
default=BuildStage.POST_BATCH,
24+
description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}"
2325
)
2426

2527
@field_validator("build_stage")
@@ -34,8 +36,26 @@ def validate_build_stage(cls, v: BuildStage) -> BuildStage:
3436
def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig:
3537
if processor_type == ProcessorType.DROP_COLUMNS:
3638
return DropColumnsProcessorConfig(**kwargs)
39+
elif processor_type == ProcessorType.TO_JSONL:
40+
return ToJsonlProcessorConfig(**kwargs)
3741

3842

3943
class DropColumnsProcessorConfig(ProcessorConfig):
4044
column_names: list[str]
4145
processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
46+
47+
48+
class ToJsonlProcessorConfig(ProcessorConfig):
49+
template: dict = Field(..., description="The template to use for each entry in the dataset.")
50+
folder_name: str = Field(..., description="Folder where JSONL files will be saved.")
51+
fraction_per_file: dict[str, float] = Field(
52+
default={"train.jsonl": 0.8, "validation.jsonl": 0.2},
53+
description="Fraction of the dataset to save in each file. The keys are the filenames and the values are the fractions.",
54+
)
55+
processor_type: Literal[ProcessorType.TO_JSONL] = ProcessorType.TO_JSONL
56+
57+
@field_validator("fraction_per_file")
58+
def validate_fraction_per_file(cls, v: dict[str, float]) -> dict[str, float]:
59+
if sum(v.values()) != 1:
60+
raise ValueError("The fractions must sum to 1.")
61+
return v

src/data_designer/engine/dataset_builders/artifact_storage.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ArtifactStorage(BaseModel):
3232
final_dataset_folder_name: str = "parquet-files"
3333
partial_results_folder_name: str = "tmp-partial-parquet-files"
3434
dropped_columns_folder_name: str = "dropped-columns-parquet-files"
35+
outputs_folder_name: str = "outputs"
3536

3637
@property
3738
def artifact_path_exists(self) -> bool:
@@ -57,6 +58,10 @@ def metadata_file_path(self) -> Path:
5758
def partial_results_path(self) -> Path:
5859
return self.base_dataset_path / self.partial_results_folder_name
5960

61+
@property
62+
def outputs_path(self) -> Path:
63+
return self.base_dataset_path / self.outputs_folder_name
64+
6065
@field_validator("artifact_path")
6166
def validate_artifact_path(cls, v: Union[Path, str]) -> Path:
6267
v = Path(v)
@@ -178,5 +183,10 @@ def write_metadata(self, metadata: dict) -> Path:
178183
json.dump(metadata, file)
179184
return self.metadata_file_path
180185

186+
def move_to_outputs(self, from_path: Path, to_folder_name: str) -> Path:
187+
self.mkdir_if_needed(self.outputs_path / to_folder_name)
188+
shutil.move(from_path, self.outputs_path / to_folder_name / from_path.name)
189+
return self.outputs_path / to_folder_name / from_path.name
190+
181191
def _get_stage_path(self, stage: BatchStage) -> Path:
182192
return getattr(self, resolve_string_enum(stage, BatchStage).value)

src/data_designer/engine/processing/processors/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from data_designer.config.processors import (
66
DropColumnsProcessorConfig,
77
ProcessorType,
8+
ToJsonlProcessorConfig,
89
)
910
from data_designer.engine.processing.processors.base import Processor
1011
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
12+
from data_designer.engine.processing.processors.to_jsonl import ToJsonlProcessor
1113
from data_designer.engine.registry.base import TaskRegistry
1214

1315

@@ -17,4 +19,5 @@ class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ...
1719
def create_default_processor_registry() -> ProcessorRegistry:
1820
registry = ProcessorRegistry()
1921
registry.register(ProcessorType.DROP_COLUMNS, DropColumnsProcessor, DropColumnsProcessorConfig, False)
22+
registry.register(ProcessorType.TO_JSONL, ToJsonlProcessor, ToJsonlProcessorConfig, False)
2023
return registry
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import json
5+
import logging
6+
import tempfile
7+
8+
import pandas as pd
9+
from pathlib import Path
10+
11+
from data_designer.config.processors import ToJsonlProcessorConfig
12+
from data_designer.engine.configurable_task import ConfigurableTaskMetadata
13+
from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
14+
from data_designer.engine.processing.processors.base import Processor
15+
from data_designer.engine.processing.utils import deserialize_json_values
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class ToJsonlProcessor(WithJinja2UserTemplateRendering, Processor[ToJsonlProcessorConfig]):
21+
@staticmethod
22+
def metadata() -> ConfigurableTaskMetadata:
23+
return ConfigurableTaskMetadata(
24+
name="to_jsonl",
25+
description="Save formatted dataset as JSONL files.",
26+
required_resources=None,
27+
)
28+
29+
@property
30+
def template_as_string(self) -> str:
31+
return json.dumps(self.config.template)
32+
33+
def _get_stop_index_per_file(self, dataset_size: int) -> dict[str, int]:
34+
"""Helper function to get the end index for each file of the split."""
35+
stop_index_per_file = {}
36+
37+
accumulated_fraction = 0.0
38+
for filename, fraction in self.config.fraction_per_file.items():
39+
accumulated_fraction += fraction
40+
stop_index_per_file[filename] = min(int(accumulated_fraction * dataset_size), dataset_size)
41+
42+
return stop_index_per_file
43+
44+
def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
45+
self.prepare_jinja2_template_renderer(self.template_as_string, data.columns.to_list())
46+
47+
stop_index_per_file = self._get_stop_index_per_file(len(data))
48+
with tempfile.TemporaryDirectory() as temp_dir:
49+
start_index = 0
50+
for filename, stop_index in stop_index_per_file.items():
51+
logger.info(f"✏️ Writing {stop_index - start_index} formatted JSONL entries to {filename}")
52+
53+
records = data.iloc[start_index:stop_index].to_dict(orient="records")
54+
with open(Path(temp_dir) / f"{filename}", "a") as f:
55+
for i, record in enumerate(records):
56+
rendered_jsonl_entry = self.render_template(deserialize_json_values(record))
57+
escaped_jsonl_entry = rendered_jsonl_entry.replace("\n", "\\n")
58+
f.write(escaped_jsonl_entry)
59+
if i < len(records) - 1:
60+
f.write("\n")
61+
start_index = stop_index
62+
63+
self.artifact_storage.move_to_outputs(Path(temp_dir) / filename, self.config.folder_name)
64+
65+
return data

src/data_designer/essentials/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
UniformDistribution,
3535
UniformDistributionParams,
3636
)
37-
from ..config.processors import DropColumnsProcessorConfig, ProcessorType
37+
from ..config.processors import DropColumnsProcessorConfig, ProcessorType, ToJsonlProcessorConfig
3838
from ..config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint
3939
from ..config.sampler_params import (
4040
BernoulliMixtureSamplerParams,
@@ -124,6 +124,7 @@
124124
"SeedDatasetColumnConfig",
125125
"SubcategorySamplerParams",
126126
"TimeDeltaSamplerParams",
127+
"ToJsonlProcessorConfig",
127128
"UniformDistribution",
128129
"UniformDistributionParams",
129130
"UniformSamplerParams",

0 commit comments

Comments
 (0)