3131from nemo_curator .stages .base import CompositeStage , ProcessingStage
3232from nemo_curator .stages .text .models .model import ModelStage
3333from nemo_curator .stages .text .models .tokenizer import TokenizerStage
34- from nemo_curator .stages .text .models .utils import ATTENTION_MASK_COLUMN , INPUT_ID_COLUMN , format_name_with_suffix
34+ from nemo_curator .stages .text .models .utils import ATTENTION_MASK_FIELD , INPUT_ID_FIELD , format_name_with_suffix
3535from nemo_curator .stages .text .modules .score_filter import Filter
3636from nemo_curator .tasks import DocumentBatch
3737
4343 "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0" ,
4444]
# Hugging Face identifier of the Instruction Data Guard model.
INSTRUCTION_DATA_GUARD_MODEL_IDENTIFIER = "nvidia/instruction-data-guard"
# Temporary column holding the prompt-wrapped text; it is written before
# tokenization and dropped again once the responses are post-processed.
HIDDEN_TEXT_FIELD = "_curator_hidden_text"
# Passed to the tokenizer stage as `max_seq_length`.
MAX_SEQ_LENGTH = 4096
# Passed to the tokenizer stage as `padding_side`.
TOKENIZER_PADDING_SIDE = "left"
TORCH_DTYPE = torch.bfloat16
@@ -154,8 +154,8 @@ def __init__( # noqa: PLR0913
154154 model_identifier : str ,
155155 cache_dir : str | None = None ,
156156 hf_token : str | None = None ,
157- pred_column : str = "preds" ,
158- prob_column : str = "probs" ,
157+ label_field : str = "preds" ,
158+ score_field : str = "probs" ,
159159 model_inference_batch_size : int = 256 ,
160160 has_seq_order : bool = True ,
161161 add_instruction_data_guard : bool = False ,
@@ -173,11 +173,11 @@ def __init__( # noqa: PLR0913
173173 )
174174
175175 self .add_instruction_data_guard = add_instruction_data_guard
176- self .pred_column = pred_column
177- self .prob_column = prob_column
176+ self .label_field = label_field
177+ self .score_field = score_field
178178
def outputs(self) -> tuple[list[str], list[str]]:
    """Declare the outputs of this stage.

    Returns:
        A pair ``(["data"], columns)``. ``columns`` always contains
        ``self.label_field``; ``self.score_field`` is appended only when
        ``self.add_instruction_data_guard`` is truthy.
    """
    columns = [self.label_field]
    if self.add_instruction_data_guard:
        columns.append(self.score_field)
    return ["data"], columns
181181
182182 # We use the _setup function to ensure that everything needed for Aegis is downloaded and loaded properly
183183 def _setup (self , local_files_only : bool = True ) -> None :
@@ -214,17 +214,17 @@ def process_model_output(
214214 ) -> dict [str , np .ndarray ]:
215215 preds = outputs .cpu ().numpy ()
216216 return {
217- self .pred_column : preds ,
217+ self .label_field : preds ,
218218 }
219219
def create_output_dataframe(self, df_cpu: pd.DataFrame, collected_output: dict[str, np.ndarray]) -> pd.DataFrame:
    """Attach the collected model output to the dataframe.

    The ``INPUT_ID_FIELD`` and ``ATTENTION_MASK_FIELD`` columns are dropped
    first. The values stored under ``self.label_field`` in
    ``collected_output`` are then written back as follows:

    * with ``self.add_instruction_data_guard`` set, the raw values go into
      ``self.score_field`` and ``self.label_field`` receives the boolean
      result of ``value >= 0.5``;
    * otherwise the raw values go into ``self.label_field`` unchanged.

    Args:
        df_cpu: Dataframe that still carries the tokenizer columns.
        collected_output: Mapping from output name to array; only the
            ``self.label_field`` entry is read.

    Returns:
        The dataframe without the tokenizer columns and with the new
        output column(s).
    """
    df_cpu = df_cpu.drop(columns=[INPUT_ID_FIELD, ATTENTION_MASK_FIELD])
    model_output = collected_output[self.label_field]

    if self.add_instruction_data_guard:
        df_cpu[self.score_field] = model_output.tolist()
        df_cpu[self.label_field] = (model_output >= 0.5).tolist()  # noqa: PLR2004
    else:
        df_cpu[self.label_field] = model_output.tolist()

    return df_cpu
230230
@@ -243,12 +243,12 @@ def inputs(self) -> tuple[list[str], list[str]]:
243243 return ["data" ], [self .text_field ]
244244
def outputs(self) -> tuple[list[str], list[str]]:
    """Declare the outputs of this stage: the ``HIDDEN_TEXT_FIELD`` column on ``"data"``."""
    return ["data"], [HIDDEN_TEXT_FIELD]
247247
def _wrap_in_prompt(self, df: pd.DataFrame) -> pd.DataFrame:
    """Build one prompt per document and store it in ``HIDDEN_TEXT_FIELD``.

    Each value of ``self.text_field`` is truncated to ``self.max_chars``
    characters before being passed to ``format_aegis``.

    Args:
        df: Dataframe containing the ``self.text_field`` column. It is
            modified in place.

    Returns:
        The same dataframe with the ``HIDDEN_TEXT_FIELD`` column added.
    """
    df[HIDDEN_TEXT_FIELD] = [format_aegis(doc[: self.max_chars]) for doc in df[self.text_field].tolist()]
    return df
253253
254254 def process (self , batch : DocumentBatch ) -> DocumentBatch :
@@ -272,16 +272,16 @@ class PostProcessAegisResponsesStage(ProcessingStage[DocumentBatch, DocumentBatc
272272
273273 cache_dir : str | None = None
274274 hf_token : str | None = None
275- pred_column : str = "aegis_pred"
276- raw_pred_column : str = "_aegis_raw_pred"
277- keep_raw_pred : bool = False
275+ label_field : str = "aegis_pred"
276+ raw_output_field : str = "_aegis_raw_pred"
277+ keep_raw_output : bool = False
278278 name = "postprocess_aegis_responses"
279279
def inputs(self) -> tuple[list[str], list[str]]:
    """Declare the inputs of this stage: ``self.raw_output_field`` and ``HIDDEN_TEXT_FIELD`` on ``"data"``."""
    return ["data"], [self.raw_output_field, HIDDEN_TEXT_FIELD]
282282
def outputs(self) -> tuple[list[str], list[str]]:
    """Declare the outputs of this stage.

    Returns:
        A pair ``(["data"], columns)``. ``columns`` always contains
        ``self.label_field``; ``self.raw_output_field`` is appended only
        when ``self.keep_raw_output`` is truthy.
    """
    columns = [self.label_field]
    if self.keep_raw_output:
        columns.append(self.raw_output_field)
    return ["data"], columns
285285
def ray_stage_spec(self) -> dict[str, Any]:
    """Return the Ray stage specification, flagging this stage with ``is_actor_stage``."""
    return {"is_actor_stage": True}
@@ -331,27 +331,27 @@ def _parse_response(self, raw_response: str) -> str:
331331 return "unknown"
332332
def _postprocess_responses(self, df: pd.DataFrame) -> pd.DataFrame:
    """Decode the raw model output and turn it into a parsed label.

    Steps:
        1. Decode the token ids in ``self.raw_output_field`` with
           ``self.tokenizer.batch_decode`` (special tokens skipped).
        2. Drop the first ``len(prompt)`` characters of each decoded string,
           where the prompt is the row's ``HIDDEN_TEXT_FIELD`` value
           (presumably the decoded text starts with the prompt itself;
           confirm against the model's generation settings).
        3. Parse what is left with ``self._parse_response``.

    Args:
        df: Dataframe containing ``self.raw_output_field`` and
            ``HIDDEN_TEXT_FIELD``.

    Returns:
        The dataframe with ``self.label_field`` added and
        ``HIDDEN_TEXT_FIELD`` removed. ``self.raw_output_field`` is replaced
        by the decoded response text when ``self.keep_raw_output`` is set,
        and dropped otherwise.
    """
    decoded = self.tokenizer.batch_decode(
        df[self.raw_output_field].tolist(),
        skip_special_tokens=True,
    )

    prompt_lengths = df[HIDDEN_TEXT_FIELD].str.len().tolist()
    responses = [text[prompt_length:] for text, prompt_length in zip(decoded, prompt_lengths, strict=False)]
    labels = [self._parse_response(response) for response in responses]

    # Assign plain lists, not pd.Series: a fresh Series carries a 0..n-1
    # index and pandas aligns on index labels, so a dataframe with any other
    # index would get NaN or misplaced values. Lists are assigned by position.
    if self.keep_raw_output:
        df[self.raw_output_field] = responses
    else:
        df = df.drop(columns=[self.raw_output_field])

    df[self.label_field] = labels

    return df.drop(columns=[HIDDEN_TEXT_FIELD])
355355
356356 def process (self , batch : DocumentBatch ) -> DocumentBatch :
357357 df = batch .to_pandas ()
@@ -388,10 +388,10 @@ class AegisClassifier(CompositeStage[DocumentBatch, DocumentBatch]):
388388 hf_token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is
389389 needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to
390390 Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b
391- pred_column (str): The name of the column to store the resulting prediction. Defaults to "aegis_pred".
392- raw_pred_column (str): The name of the column to store the raw output of the AEGIS LLM before
391+ label_field (str): The name of the column to store the resulting prediction. Defaults to "aegis_pred".
392+ raw_output_field (str): The name of the column to store the raw output of the AEGIS LLM before
393393 the prediction is extracted from it. Defaults to "_aegis_raw_pred".
394- keep_raw_pred (bool): If True, will keep the unprocessed LLM output in raw_pred_column .
394+ keep_raw_output (bool): If True, will keep the unprocessed LLM output in raw_output_field .
395395 Useful for debugging when "unknown" shows up a lot in your dataset. Defaults to False.
396396 text_field (str): The field in the dataset that should be classified. Defaults to "text".
397397 filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values
@@ -407,9 +407,9 @@ class AegisClassifier(CompositeStage[DocumentBatch, DocumentBatch]):
407407 aegis_variant : Literal [AEGIS_VARIANTS ] = AEGIS_VARIANTS [0 ]
408408 cache_dir : str | None = None
409409 hf_token : str | bool | None = None
410- pred_column : str = "aegis_pred"
411- raw_pred_column : str = "_aegis_raw_pred"
412- keep_raw_pred : bool = False
410+ label_field : str = "aegis_pred"
411+ raw_output_field : str = "_aegis_raw_pred"
412+ keep_raw_output : bool = False
413413 text_field : str = "text"
414414 filter_by : list [str ] | None = None
415415 max_chars : int = 6000
@@ -431,7 +431,7 @@ def __post_init__(self) -> None:
431431 model_identifier = PRETRAINED_MODEL_NAME_OR_PATH ,
432432 cache_dir = self .cache_dir ,
433433 hf_token = self .hf_token ,
434- text_field = HIDDEN_TEXT_COLUMN ,
434+ text_field = HIDDEN_TEXT_FIELD ,
435435 max_seq_length = MAX_SEQ_LENGTH ,
436436 padding_side = TOKENIZER_PADDING_SIDE ,
437437 sort_by_length = self .sort_by_length ,
@@ -441,7 +441,7 @@ def __post_init__(self) -> None:
441441 model_identifier = self .aegis_variant ,
442442 cache_dir = self .cache_dir ,
443443 hf_token = self .hf_token ,
444- pred_column = self .raw_pred_column ,
444+ label_field = self .raw_output_field ,
445445 model_inference_batch_size = self .model_inference_batch_size ,
446446 has_seq_order = self .sort_by_length ,
447447 add_instruction_data_guard = False ,
@@ -450,14 +450,14 @@ def __post_init__(self) -> None:
450450 PostProcessAegisResponsesStage (
451451 cache_dir = self .cache_dir ,
452452 hf_token = self .hf_token ,
453- pred_column = self .pred_column ,
454- raw_pred_column = self .raw_pred_column ,
455- keep_raw_pred = self .keep_raw_pred ,
453+ label_field = self .label_field ,
454+ raw_output_field = self .raw_output_field ,
455+ keep_raw_output = self .keep_raw_output ,
456456 ),
457457 ]
458458
459459 if self .filter_by is not None and len (self .filter_by ) > 0 :
460- self .stages .append (Filter (filter_fn = self .filter_by_category , filter_field = self .pred_column ))
460+ self .stages .append (Filter (filter_fn = self .filter_by_category , filter_field = self .label_field ))
461461
def inputs(self) -> tuple[list[str], list[str]]:
    """Return the inputs of the composite stage, delegating to its first sub-stage."""
    return self.stages[0].inputs()
@@ -519,8 +519,8 @@ class InstructionDataGuardClassifier(CompositeStage[DocumentBatch, DocumentBatch
519519 hf_token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is
520520 needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to
521521 Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b
522- pred_column (str): The name of the column to store the resulting prediction. Defaults to "is_poisoned".
523- prob_column (str): The name of the column to store the poisoning probability score. Defaults to "instruction_data_guard_poisoning_score".
522+ label_field (str): The name of the column to store the resulting prediction. Defaults to "is_poisoned".
523+ score_field (str): The name of the column to store the poisoning probability score. Defaults to "instruction_data_guard_poisoning_score".
524524 text_field (str): The field in the dataset that should be classified. Defaults to "text".
525525 filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values
526526 except those specified in this list. Defaults to None.
@@ -534,8 +534,8 @@ class InstructionDataGuardClassifier(CompositeStage[DocumentBatch, DocumentBatch
534534
535535 cache_dir : str | None = None
536536 hf_token : str | bool | None = None
537- pred_column : str = "is_poisoned"
538- prob_column : str = "instruction_data_guard_poisoning_score"
537+ label_field : str = "is_poisoned"
538+ score_field : str = "instruction_data_guard_poisoning_score"
539539 text_field : str = "text"
540540 filter_by : list [str ] | None = None
541541 max_chars : int = 6000
@@ -564,8 +564,8 @@ def __post_init__(self) -> None:
564564 model_identifier = AEGIS_VARIANTS [0 ],
565565 cache_dir = self .cache_dir ,
566566 hf_token = self .hf_token ,
567- pred_column = self .pred_column ,
568- prob_column = self .prob_column ,
567+ label_field = self .label_field ,
568+ score_field = self .score_field ,
569569 model_inference_batch_size = self .model_inference_batch_size ,
570570 has_seq_order = self .sort_by_length ,
571571 add_instruction_data_guard = True ,
@@ -574,7 +574,7 @@ def __post_init__(self) -> None:
574574 ]
575575
576576 if self .filter_by is not None and len (self .filter_by ) > 0 :
577- self .stages .append (Filter (filter_fn = self .filter_by_category , filter_field = self .pred_column ))
577+ self .stages .append (Filter (filter_fn = self .filter_by_category , filter_field = self .label_field ))
578578
def inputs(self) -> tuple[list[str], list[str]]:
    """Return the inputs of the composite stage, delegating to its first sub-stage."""
    return self.stages[0].inputs()
0 commit comments