Skip to content

Commit 39e5f0b

Browse files
Update field names for classifiers (#1220)
* rename prob_column to score_field Signed-off-by: Sarah Yurick <[email protected]> * rename pred_column to label_field, raw_pred_column to raw_output_field, keep_raw_pred to keep_raw_output Signed-off-by: Sarah Yurick <[email protected]> * more name updates Signed-off-by: Sarah Yurick <[email protected]> * update embeddings as needed Signed-off-by: Sarah Yurick <[email protected]> * Apply suggestions from code review Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Sarah Yurick <[email protected]> * update outdated readme Signed-off-by: Sarah Yurick <[email protected]> --------- Signed-off-by: Sarah Yurick <[email protected]> Signed-off-by: Sarah Yurick <[email protected]> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 2d90b4b commit 39e5f0b

File tree

18 files changed

+276
-276
lines changed

18 files changed

+276
-276
lines changed

docs/curate-text/process-data/quality-assessment/classifier.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ The `QualityClassifier` accepts the following parameters:
339339
- `filter_by` (list, default=None): Quality levels to keep (options: "Low", "Medium", "High")
340340
- `model_inference_batch_size` (int, default=256): Batch size for inference
341341
- `max_chars` (int, default=6000): Max characters per document for processing
342-
- `pred_column` (str, default="quality_pred"): Name of the prediction column
342+
- `label_field` (str, default="quality_pred"): Name of the prediction column
343343
- `text_field` (str, default="text"): Name of the text field in input data
344344

345345
### FastTextQualityFilter
@@ -363,7 +363,7 @@ classifiers:
363363
filter_by: ["High"]
364364
model_inference_batch_size: 256
365365
max_chars: 6000
366-
pred_column: quality_pred
366+
label_field: quality_pred
367367
text_field: text
368368
```
369369

docs/curate-text/process-data/quality-assessment/distributed-classifier.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ NVIDIA NeMo Curator provides a base class `DistributedDataClassifier` that can b
4343
| MultilingualDomainClassifier | Categorize text in 52 languages by domain | [nvidia/multilingual-domain-classifier](https://huggingface.co/nvidia/multilingual-domain-classifier) | `filter_by`, `text_field` | None |
4444
| QualityClassifier | Assess document quality | [nvidia/quality-classifier-deberta](https://huggingface.co/nvidia/quality-classifier-deberta) | `filter_by`, `text_field` | None |
4545
| AegisClassifier | Detect unsafe content | [nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0](https://huggingface.co/nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0) | `aegis_variant`, `filter_by` | HuggingFace token |
46-
| InstructionDataGuardClassifier | Detect poisoning attacks | [nvidia/instruction-data-guard](https://huggingface.co/nvidia/instruction-data-guard) | `text_field`, `pred_column` | HuggingFace token |
47-
| FineWebEduClassifier | Score educational value | [HuggingFaceFW/fineweb-edu-classifier](https://huggingface.co/HuggingFaceFW/fineweb-edu-classifier) | `pred_column`, `int_column` | None |
48-
| FineWebMixtralEduClassifier | Score educational value (Mixtral annotations) | [nvidia/nemocurator-fineweb-mixtral-edu-classifier](https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier) | `pred_column`, `int_column`, `model_inference_batch_size=1024` | None |
49-
| FineWebNemotronEduClassifier | Score educational value (Nemotron annotations) | [nvidia/nemocurator-fineweb-nemotron-4-edu-classifier](https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier) | `pred_column`, `int_column`, `model_inference_batch_size=1024` | None |
46+
| InstructionDataGuardClassifier | Detect poisoning attacks | [nvidia/instruction-data-guard](https://huggingface.co/nvidia/instruction-data-guard) | `text_field`, `label_field` | HuggingFace token |
47+
| FineWebEduClassifier | Score educational value | [HuggingFaceFW/fineweb-edu-classifier](https://huggingface.co/HuggingFaceFW/fineweb-edu-classifier) | `label_field`, `int_field` | None |
48+
| FineWebMixtralEduClassifier | Score educational value (Mixtral annotations) | [nvidia/nemocurator-fineweb-mixtral-edu-classifier](https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier) | `label_field`, `int_field`, `model_inference_batch_size=1024` | None |
49+
| FineWebNemotronEduClassifier | Score educational value (Nemotron annotations) | [nvidia/nemocurator-fineweb-nemotron-4-edu-classifier](https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier) | `label_field`, `int_field`, `model_inference_batch_size=1024` | None |
5050
| ContentTypeClassifier | Categorize by speech type | [nvidia/content-type-classifier-deberta](https://huggingface.co/nvidia/content-type-classifier-deberta) | `filter_by`, `text_field` | None |
5151
| PromptTaskComplexityClassifier | Classify prompt tasks and complexity | [nvidia/prompt-task-and-complexity-classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier) | `text_field` | None |
5252

@@ -165,8 +165,8 @@ The classifier adds a column with labels: "safe," "O1" through "O13" (each repre
165165
safety_classifier = AegisClassifier(
166166
aegis_variant="nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
167167
hf_token=token,
168-
keep_raw_pred=True,
169-
raw_pred_column="raw_predictions"
168+
keep_raw_output=True,
169+
raw_output_field="raw_predictions"
170170
)
171171
```
172172

@@ -239,9 +239,9 @@ pipeline.add_stage(reader)
239239
# Apply the FineWeb Edu classifier
240240
edu_classifier = FineWebEduClassifier(
241241
model_inference_batch_size=256,
242-
float_score_column="fineweb-edu-score-float", # Raw float scores
243-
int_score_column="fineweb-edu-score-int", # Rounded integer scores
244-
pred_column="fineweb-edu-score-label" # Quality labels
242+
float_score_field="fineweb-edu-score-float", # Raw float scores
243+
int_score_field="fineweb-edu-score-int", # Rounded integer scores
244+
label_field="fineweb-edu-score-label" # Quality labels
245245
)
246246
pipeline.add_stage(edu_classifier)
247247

@@ -287,9 +287,9 @@ pipeline.add_stage(reader)
287287

288288
# Apply the FineWeb Mixtral Edu classifier
289289
classifier = FineWebMixtralEduClassifier(
290-
float_score_column="fineweb-mixtral-edu-score-float", # Raw float scores
291-
int_score_column="fineweb-mixtral-edu-score-int", # Rounded integer scores
292-
pred_column="fineweb-mixtral-edu-score-label" # "high_quality" or "low_quality"
290+
float_score_field="fineweb-mixtral-edu-score-float", # Raw float scores
291+
int_score_field="fineweb-mixtral-edu-score-int", # Rounded integer scores
292+
label_field="fineweb-mixtral-edu-score-label" # "high_quality" or "low_quality"
293293
)
294294
pipeline.add_stage(classifier)
295295

nemo_curator/stages/text/classifiers/aegis.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from nemo_curator.stages.base import CompositeStage, ProcessingStage
3232
from nemo_curator.stages.text.models.model import ModelStage
3333
from nemo_curator.stages.text.models.tokenizer import TokenizerStage
34-
from nemo_curator.stages.text.models.utils import ATTENTION_MASK_COLUMN, INPUT_ID_COLUMN, format_name_with_suffix
34+
from nemo_curator.stages.text.models.utils import ATTENTION_MASK_FIELD, INPUT_ID_FIELD, format_name_with_suffix
3535
from nemo_curator.stages.text.modules.score_filter import Filter
3636
from nemo_curator.tasks import DocumentBatch
3737

@@ -43,7 +43,7 @@
4343
"nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0",
4444
]
4545
INSTRUCTION_DATA_GUARD_MODEL_IDENTIFIER = "nvidia/instruction-data-guard"
46-
HIDDEN_TEXT_COLUMN = "_curator_hidden_text"
46+
HIDDEN_TEXT_FIELD = "_curator_hidden_text"
4747
MAX_SEQ_LENGTH = 4096
4848
TOKENIZER_PADDING_SIDE = "left"
4949
TORCH_DTYPE = torch.bfloat16
@@ -154,8 +154,8 @@ def __init__( # noqa: PLR0913
154154
model_identifier: str,
155155
cache_dir: str | None = None,
156156
hf_token: str | None = None,
157-
pred_column: str = "preds",
158-
prob_column: str = "probs",
157+
label_field: str = "preds",
158+
score_field: str = "probs",
159159
model_inference_batch_size: int = 256,
160160
has_seq_order: bool = True,
161161
add_instruction_data_guard: bool = False,
@@ -173,11 +173,11 @@ def __init__( # noqa: PLR0913
173173
)
174174

175175
self.add_instruction_data_guard = add_instruction_data_guard
176-
self.pred_column = pred_column
177-
self.prob_column = prob_column
176+
self.label_field = label_field
177+
self.score_field = score_field
178178

179179
def outputs(self) -> tuple[list[str], list[str]]:
180-
return ["data"], [self.pred_column] + ([self.prob_column] if self.add_instruction_data_guard else [])
180+
return ["data"], [self.label_field] + ([self.score_field] if self.add_instruction_data_guard else [])
181181

182182
# We use the _setup function to ensure that everything needed for Aegis is downloaded and loaded properly
183183
def _setup(self, local_files_only: bool = True) -> None:
@@ -214,17 +214,17 @@ def process_model_output(
214214
) -> dict[str, np.ndarray]:
215215
preds = outputs.cpu().numpy()
216216
return {
217-
self.pred_column: preds,
217+
self.label_field: preds,
218218
}
219219

220220
def create_output_dataframe(self, df_cpu: pd.DataFrame, collected_output: dict[str, np.ndarray]) -> pd.DataFrame:
221-
df_cpu = df_cpu.drop(columns=[INPUT_ID_COLUMN, ATTENTION_MASK_COLUMN])
221+
df_cpu = df_cpu.drop(columns=[INPUT_ID_FIELD, ATTENTION_MASK_FIELD])
222222

223223
if self.add_instruction_data_guard:
224-
df_cpu[self.prob_column] = collected_output[self.pred_column].tolist()
225-
df_cpu[self.pred_column] = (collected_output[self.pred_column] >= 0.5).tolist() # noqa: PLR2004
224+
df_cpu[self.score_field] = collected_output[self.label_field].tolist()
225+
df_cpu[self.label_field] = (collected_output[self.label_field] >= 0.5).tolist() # noqa: PLR2004
226226
else:
227-
df_cpu[self.pred_column] = collected_output[self.pred_column].tolist()
227+
df_cpu[self.label_field] = collected_output[self.label_field].tolist()
228228

229229
return df_cpu
230230

@@ -243,12 +243,12 @@ def inputs(self) -> tuple[list[str], list[str]]:
243243
return ["data"], [self.text_field]
244244

245245
def outputs(self) -> tuple[list[str], list[str]]:
246-
return ["data"], [HIDDEN_TEXT_COLUMN]
246+
return ["data"], [HIDDEN_TEXT_FIELD]
247247

248248
def _wrap_in_prompt(self, df: pd.DataFrame) -> pd.DataFrame:
249249
documents = df[self.text_field].tolist()
250250
prompts = [format_aegis(doc[: self.max_chars]) for doc in documents]
251-
df[HIDDEN_TEXT_COLUMN] = prompts
251+
df[HIDDEN_TEXT_FIELD] = prompts
252252
return df
253253

254254
def process(self, batch: DocumentBatch) -> DocumentBatch:
@@ -272,16 +272,16 @@ class PostProcessAegisResponsesStage(ProcessingStage[DocumentBatch, DocumentBatc
272272

273273
cache_dir: str | None = None
274274
hf_token: str | None = None
275-
pred_column: str = "aegis_pred"
276-
raw_pred_column: str = "_aegis_raw_pred"
277-
keep_raw_pred: bool = False
275+
label_field: str = "aegis_pred"
276+
raw_output_field: str = "_aegis_raw_pred"
277+
keep_raw_output: bool = False
278278
name = "postprocess_aegis_responses"
279279

280280
def inputs(self) -> tuple[list[str], list[str]]:
281-
return ["data"], [self.raw_pred_column, HIDDEN_TEXT_COLUMN]
281+
return ["data"], [self.raw_output_field, HIDDEN_TEXT_FIELD]
282282

283283
def outputs(self) -> tuple[list[str], list[str]]:
284-
return ["data"], [self.pred_column] + ([self.raw_pred_column] if self.keep_raw_pred else [])
284+
return ["data"], [self.label_field] + ([self.raw_output_field] if self.keep_raw_output else [])
285285

286286
def ray_stage_spec(self) -> dict[str, Any]:
287287
return {"is_actor_stage": True}
@@ -331,27 +331,27 @@ def _parse_response(self, raw_response: str) -> str:
331331
return "unknown"
332332

333333
def _postprocess_responses(self, df: pd.DataFrame) -> pd.DataFrame:
334-
generated_tokens = df[self.raw_pred_column].tolist()
334+
generated_tokens = df[self.raw_output_field].tolist()
335335

336336
generated_tokens = self.tokenizer.batch_decode(
337337
generated_tokens,
338338
skip_special_tokens=True,
339339
)
340340

341-
original_lengths = df[HIDDEN_TEXT_COLUMN].str.len().tolist()
341+
original_lengths = df[HIDDEN_TEXT_FIELD].str.len().tolist()
342342
generated_tokens = [
343343
chars[original_length:] for chars, original_length in zip(generated_tokens, original_lengths, strict=False)
344344
]
345345
parsed_response = [self._parse_response(response) for response in generated_tokens]
346346

347-
if self.keep_raw_pred:
348-
df[self.raw_pred_column] = pd.Series(generated_tokens)
347+
if self.keep_raw_output:
348+
df[self.raw_output_field] = pd.Series(generated_tokens)
349349
else:
350-
df = df.drop(columns=[self.raw_pred_column])
350+
df = df.drop(columns=[self.raw_output_field])
351351

352-
df[self.pred_column] = pd.Series(parsed_response)
352+
df[self.label_field] = pd.Series(parsed_response)
353353

354-
return df.drop(columns=[HIDDEN_TEXT_COLUMN])
354+
return df.drop(columns=[HIDDEN_TEXT_FIELD])
355355

356356
def process(self, batch: DocumentBatch) -> DocumentBatch:
357357
df = batch.to_pandas()
@@ -388,10 +388,10 @@ class AegisClassifier(CompositeStage[DocumentBatch, DocumentBatch]):
388388
hf_token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is
389389
needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to
390390
Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b
391-
pred_column (str): The name of the column to store the resulting prediction. Defaults to "aegis_pred".
392-
raw_pred_column (str): The name of the column to store the raw output of the AEGIS LLM before
391+
label_field (str): The name of the column to store the resulting prediction. Defaults to "aegis_pred".
392+
raw_output_field (str): The name of the column to store the raw output of the AEGIS LLM before
393393
the prediction is extracted from it. Defaults to "_aegis_raw_pred".
394-
keep_raw_pred (bool): If True, will keep the unprocessed LLM output in raw_pred_column.
394+
keep_raw_output (bool): If True, will keep the unprocessed LLM output in raw_output_field.
395395
Useful for debugging when "unknown" shows up a lot in your dataset. Defaults to False.
396396
text_field (str): The field in the dataset that should be classified. Defaults to "text".
397397
filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values
@@ -407,9 +407,9 @@ class AegisClassifier(CompositeStage[DocumentBatch, DocumentBatch]):
407407
aegis_variant: Literal[AEGIS_VARIANTS] = AEGIS_VARIANTS[0]
408408
cache_dir: str | None = None
409409
hf_token: str | bool | None = None
410-
pred_column: str = "aegis_pred"
411-
raw_pred_column: str = "_aegis_raw_pred"
412-
keep_raw_pred: bool = False
410+
label_field: str = "aegis_pred"
411+
raw_output_field: str = "_aegis_raw_pred"
412+
keep_raw_output: bool = False
413413
text_field: str = "text"
414414
filter_by: list[str] | None = None
415415
max_chars: int = 6000
@@ -431,7 +431,7 @@ def __post_init__(self) -> None:
431431
model_identifier=PRETRAINED_MODEL_NAME_OR_PATH,
432432
cache_dir=self.cache_dir,
433433
hf_token=self.hf_token,
434-
text_field=HIDDEN_TEXT_COLUMN,
434+
text_field=HIDDEN_TEXT_FIELD,
435435
max_seq_length=MAX_SEQ_LENGTH,
436436
padding_side=TOKENIZER_PADDING_SIDE,
437437
sort_by_length=self.sort_by_length,
@@ -441,7 +441,7 @@ def __post_init__(self) -> None:
441441
model_identifier=self.aegis_variant,
442442
cache_dir=self.cache_dir,
443443
hf_token=self.hf_token,
444-
pred_column=self.raw_pred_column,
444+
label_field=self.raw_output_field,
445445
model_inference_batch_size=self.model_inference_batch_size,
446446
has_seq_order=self.sort_by_length,
447447
add_instruction_data_guard=False,
@@ -450,14 +450,14 @@ def __post_init__(self) -> None:
450450
PostProcessAegisResponsesStage(
451451
cache_dir=self.cache_dir,
452452
hf_token=self.hf_token,
453-
pred_column=self.pred_column,
454-
raw_pred_column=self.raw_pred_column,
455-
keep_raw_pred=self.keep_raw_pred,
453+
label_field=self.label_field,
454+
raw_output_field=self.raw_output_field,
455+
keep_raw_output=self.keep_raw_output,
456456
),
457457
]
458458

459459
if self.filter_by is not None and len(self.filter_by) > 0:
460-
self.stages.append(Filter(filter_fn=self.filter_by_category, filter_field=self.pred_column))
460+
self.stages.append(Filter(filter_fn=self.filter_by_category, filter_field=self.label_field))
461461

462462
def inputs(self) -> tuple[list[str], list[str]]:
463463
return self.stages[0].inputs()
@@ -519,8 +519,8 @@ class InstructionDataGuardClassifier(CompositeStage[DocumentBatch, DocumentBatch
519519
hf_token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is
520520
needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to
521521
Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b
522-
pred_column (str): The name of the column to store the resulting prediction. Defaults to "is_poisoned".
523-
prob_column (str): The name of the column to store the poisoning probability score. Defaults to "instruction_data_guard_poisoning_score".
522+
label_field (str): The name of the column to store the resulting prediction. Defaults to "is_poisoned".
523+
score_field (str): The name of the column to store the poisoning probability score. Defaults to "instruction_data_guard_poisoning_score".
524524
text_field (str): The field in the dataset that should be classified. Defaults to "text".
525525
filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values
526526
expect those specified in this list. Defaults to None.
@@ -534,8 +534,8 @@ class InstructionDataGuardClassifier(CompositeStage[DocumentBatch, DocumentBatch
534534

535535
cache_dir: str | None = None
536536
hf_token: str | bool | None = None
537-
pred_column: str = "is_poisoned"
538-
prob_column: str = "instruction_data_guard_poisoning_score"
537+
label_field: str = "is_poisoned"
538+
score_field: str = "instruction_data_guard_poisoning_score"
539539
text_field: str = "text"
540540
filter_by: list[str] | None = None
541541
max_chars: int = 6000
@@ -564,8 +564,8 @@ def __post_init__(self) -> None:
564564
model_identifier=AEGIS_VARIANTS[0],
565565
cache_dir=self.cache_dir,
566566
hf_token=self.hf_token,
567-
pred_column=self.pred_column,
568-
prob_column=self.prob_column,
567+
label_field=self.label_field,
568+
score_field=self.score_field,
569569
model_inference_batch_size=self.model_inference_batch_size,
570570
has_seq_order=self.sort_by_length,
571571
add_instruction_data_guard=True,
@@ -574,7 +574,7 @@ def __post_init__(self) -> None:
574574
]
575575

576576
if self.filter_by is not None and len(self.filter_by) > 0:
577-
self.stages.append(Filter(filter_fn=self.filter_by_category, filter_field=self.pred_column))
577+
self.stages.append(Filter(filter_fn=self.filter_by_category, filter_field=self.label_field))
578578

579579
def inputs(self) -> tuple[list[str], list[str]]:
580580
return self.stages[0].inputs()

0 commit comments

Comments
 (0)