Skip to content

Commit 3483ee4

Browse files
authored
Minor CrossFit improvements (#483)
Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
1 parent a43abdd commit 3483ee4

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

nemo_curator/classifiers/base.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -123,10 +123,13 @@ def _run_classifier_helper(
123123
prob_col: str = None,
124124
) -> "dask_cudf.DataFrame":
125125

126-
if prob_col:
127-
df[prob_col] = 0
128-
else:
126+
if prob_col is None:
129127
prob_col = "_prob"
128+
labeler = op.Labeler(labels, cols=[prob_col], suffix=label_col)
129+
else:
130+
labeler = op.Labeler(
131+
labels, cols=[prob_col], keep_cols=[prob_col], suffix=label_col
132+
)
130133

131134
columns_to_keep_list = df.columns.to_list()
132135

@@ -140,7 +143,7 @@ def _run_classifier_helper(
140143
batch_size=batch_size,
141144
pred_output_col=prob_col,
142145
),
143-
op.Labeler(labels, cols=[prob_col], suffix=label_col),
146+
labeler,
144147
repartition=df.npartitions,
145148
keep_cols=columns_to_keep_list,
146149
)

nemo_curator/classifiers/prompt_task_complexity.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -337,11 +337,15 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
337337

338338
df = dataset.df
339339
columns_to_keep_list = df.columns.to_list()
340-
df["sliced_text"] = df[self.text_field].str.slice(0, self.max_chars)
341340

342341
model = self.model
343342
classifier_pipe = op.Sequential(
344-
op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="default"),
343+
op.Tokenizer(
344+
model,
345+
cols=[self.text_field],
346+
tokenizer_type="default",
347+
max_chars=self.max_chars,
348+
),
345349
op.Predictor(
346350
model,
347351
sorted_data_loader=True,

0 commit comments

Comments
 (0)