Commit 8993985

Improve text embedding generation (#1064)
1 parent 1f91290 · commit 8993985

7 files changed: +59 −81 lines

nemo_curator/stages/text/classifiers/aegis.py

Lines changed: 2 additions & 13 deletions

@@ -85,7 +85,6 @@ def __init__( # noqa: PLR0913
         local_files_only: bool = True,
         hf_token: str | bool | None = None,
         add_instruction_data_guard: bool = False,
-        autocast: bool = False,
     ):
         super().__init__()

@@ -107,7 +106,6 @@ def __init__( # noqa: PLR0913
             cache_dir=cache_dir,
             local_files_only=local_files_only,
         )
-        self.autocast = autocast
         self.add_instruction_data_guard = add_instruction_data_guard
         if self.add_instruction_data_guard:
             self.instruction_data_guard_net = InstructionDataGuardNet(4096)
@@ -117,7 +115,7 @@ def device(self) -> torch.device:
         return next(self.parameters()).device

     @torch.no_grad()
-    def _forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
+    def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         batch = {k: v.to(TORCH_DTYPE) if v.dtype.is_floating_point else v for k, v in batch.items()}

         if self.add_instruction_data_guard:
@@ -145,14 +143,6 @@ def _forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:

         return response

-    @torch.no_grad()
-    def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
-        if self.autocast:
-            with torch.autocast(device_type="cuda"):
-                return self._forward(batch)
-        else:
-            return self._forward(batch)
-

 class AegisModelStage(ModelStage):
     """
@@ -179,12 +169,12 @@ def __init__( # noqa: PLR0913
             has_seq_order=has_seq_order,
             padding_side=TOKENIZER_PADDING_SIDE,
             unpack_inference_batch=False,
+            autocast=autocast,
         )

         self.add_instruction_data_guard = add_instruction_data_guard
         self.pred_column = pred_column
         self.prob_column = prob_column
-        self.autocast = autocast

     def outputs(self) -> tuple[list[str], list[str]]:
         return ["data"], [self.pred_column] + ([self.prob_column] if self.add_instruction_data_guard else [])
@@ -199,7 +189,6 @@ def _setup(self, local_files_only: bool = True) -> None:
             local_files_only=local_files_only,
             hf_token=self.hf_token,
             add_instruction_data_guard=self.add_instruction_data_guard,
-            autocast=self.autocast,
         )
         if self.add_instruction_data_guard:
             self.model.instruction_data_guard_net = self.model.instruction_data_guard_net.from_pretrained(
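
The deletions above all follow one pattern: `AegisModel` loses its `autocast` flag and its wrapper `forward`, `_forward` is promoted to be the real `forward`, and the stage forwards `autocast=autocast` to the shared `ModelStage` base instead. A minimal sketch (toy linear model, not the Aegis weights) of what stage-level wrapping does:

```python
import torch

# Hypothetical stand-in for a classifier; under autocast the CUDA matmul
# runs in reduced precision without the model branching on a flag itself.
model = torch.nn.Linear(8, 2).cuda().eval()
batch = torch.randn(4, 8, device="cuda")

with torch.no_grad(), torch.autocast(device_type="cuda"):
    out = model(batch)

print(out.dtype)  # typically torch.float16 under CUDA autocast
```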

nemo_curator/stages/text/classifiers/base.py

Lines changed: 10 additions & 16 deletions

@@ -53,7 +53,7 @@ def device(self) -> torch.device:
         return next(self.parameters()).device

     @torch.no_grad()
-    def _forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
+    def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         features = self.model(batch[INPUT_ID_COLUMN], batch[ATTENTION_MASK_COLUMN]).last_hidden_state
         dropped = self.dropout(features)
         outputs = self.fc(dropped)
@@ -62,17 +62,6 @@ def _forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:

         return torch.softmax(outputs[:, 0, :], dim=1)

-    @torch.no_grad()
-    def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
-        if self.autocast:
-            with torch.autocast(device_type="cuda"):
-                return self._forward(batch)
-        else:
-            return self._forward(batch)
-
-    def set_autocast(self, autocast: bool) -> None:
-        self.autocast = autocast
-

 class ClassifierModelStage(ModelStage):
     """
@@ -109,6 +98,7 @@ def __init__( # noqa: PLR0913
             model_inference_batch_size=model_inference_batch_size,
             padding_side=padding_side,
             unpack_inference_batch=False,
+            autocast=autocast,
         )

         self.pred_column = pred_column
@@ -118,16 +108,20 @@ def __init__( # noqa: PLR0913
         else:
             self.prob_column = "probs"
             self.keep_prob_column = False
-        self.autocast = autocast

     def outputs(self) -> tuple[list[str], list[str]]:
         return ["data"], [self.pred_column] + ([self.prob_column] if self.keep_prob_column else [])

     def _setup(self, local_files_only: bool = True) -> None:
-        self.model = Deberta.from_pretrained(self.model_identifier, cache_dir=self.cache_dir, local_files_only=local_files_only).cuda().eval()
-        self.model.set_autocast(self.autocast)
+        self.model = (
+            Deberta.from_pretrained(self.model_identifier, cache_dir=self.cache_dir, local_files_only=local_files_only)
+            .cuda()
+            .eval()
+        )

-        config = AutoConfig.from_pretrained(self.model_identifier, cache_dir=self.cache_dir, local_files_only=local_files_only)
+        config = AutoConfig.from_pretrained(
+            self.model_identifier, cache_dir=self.cache_dir, local_files_only=local_files_only
+        )
         self.labels = list(config.label2id.keys())
         self.labels.sort(key=lambda x: config.label2id[x])
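
Same refactor for the generic Deberta classifier: `_forward` becomes `forward`, `set_autocast` disappears, and `_setup` builds the model with a parenthesized `.cuda().eval()` chain. The label sorting at the end is the part worth a note: the label list must line up with the class-id axis of the softmax output. A hedged toy illustration (invented `label2id`; the real mapping comes from the HF `AutoConfig`):

```python
# Invented mapping for illustration; config.label2id supplies the real one.
label2id = {"toxic": 1, "clean": 0}

labels = list(label2id.keys())
labels.sort(key=lambda x: label2id[x])
print(labels)  # ['clean', 'toxic'] -- index i lines up with probs[:, i]

probs = [0.8, 0.2]  # one row of torch.softmax(outputs[:, 0, :], dim=1)
print(labels[probs.index(max(probs))])  # 'clean'
```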

nemo_curator/stages/text/classifiers/fineweb_edu.py

Lines changed: 6 additions & 10 deletions

@@ -84,19 +84,13 @@ def outputs(self) -> tuple[list[str], list[str]]:
         return ["data"], [self.pred_column, self.float_score_column, self.int_score_column]

     @staticmethod
-    def configure_forward(model: torch.nn.Module, autocast: bool = True) -> torch.nn.Module:
+    def configure_forward(model: torch.nn.Module) -> torch.nn.Module:
         original_forward = model.forward

         @torch.no_grad()
         def custom_forward(*args, **kwargs) -> torch.Tensor:
-            if autocast:
-                with torch.autocast(device_type="cuda"):
-                    output = original_forward(*args, **kwargs)
-            else:
-                output = original_forward(*args, **kwargs)
-
+            output = original_forward(*args, **kwargs)
             del args, kwargs
-
             return output.logits.squeeze(-1).float()

         model.forward = custom_forward
@@ -108,9 +102,11 @@ def _setup(self, local_files_only: bool = True) -> None:
             cache_dir=self.cache_dir,
             local_files_only=local_files_only,
         ).cuda()
-        self.model = self.configure_forward(model, self.autocast)
+        self.model = self.configure_forward(model)

-    def process_model_output(self, outputs: torch.Tensor, _: dict[str, torch.Tensor] | None = None) -> dict[str, np.ndarray]:
+    def process_model_output(
+        self, outputs: torch.Tensor, _: dict[str, torch.Tensor] | None = None
+    ) -> dict[str, np.ndarray]:
         logits = outputs.cpu().numpy()

         float_scores = logits.tolist()
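
`configure_forward` keeps its monkey-patching approach but no longer takes an `autocast` argument. A self-contained sketch of the pattern (toy `torch.nn.Linear` instead of the FineWeb-Edu regressor, squeezing the raw output rather than `.logits`):

```python
import torch

def configure_forward(model: torch.nn.Module) -> torch.nn.Module:
    original_forward = model.forward

    @torch.no_grad()
    def custom_forward(*args, **kwargs) -> torch.Tensor:
        # Post-process the raw output; autocast is now the caller's concern.
        output = original_forward(*args, **kwargs)
        return output.squeeze(-1).float()

    model.forward = custom_forward
    return model

model = configure_forward(torch.nn.Linear(4, 1))
print(model(torch.randn(3, 4)).shape)  # torch.Size([3])
```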

nemo_curator/stages/text/classifiers/prompt_task_complexity.py

Lines changed: 10 additions & 14 deletions

@@ -210,14 +210,7 @@ def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         input_ids = batch[INPUT_ID_COLUMN]
         attention_mask = batch[ATTENTION_MASK_COLUMN]

-        if self.autocast:
-            with torch.autocast(device_type="cuda"):
-                return self._forward(input_ids, attention_mask)
-        else:
-            return self._forward(input_ids, attention_mask)
-
-    def set_autocast(self, autocast: bool) -> None:
-        self.autocast = autocast
+        return self._forward(input_ids, attention_mask)


 class PromptTaskComplexityModelStage(ModelStage):
@@ -256,12 +249,15 @@ def outputs(self) -> tuple[list[str], list[str]]:
         return ["data"], OUTPUT_COLUMNS

     def _setup(self, local_files_only: bool = True) -> None:
-        self.model = CustomDeberta.from_pretrained(
-            self.model_identifier,
-            cache_dir=self.cache_dir,
-            local_files_only=local_files_only,
-        ).cuda().eval()
-        self.model.set_autocast(self.autocast)
+        self.model = (
+            CustomDeberta.from_pretrained(
+                self.model_identifier,
+                cache_dir=self.cache_dir,
+                local_files_only=local_files_only,
+            )
+            .cuda()
+            .eval()
+        )

     def process_model_output(self, outputs: torch.Tensor, _: dict[str, torch.Tensor] | None = None) -> torch.Tensor:
         return outputs
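
Beyond dropping the autocast branch, `_setup` is reformatted into a parenthesized chain. The `.eval()` call in that chain matters for determinism: it switches off dropout at inference time. A quick demonstration with a bare dropout layer (toy example, not from the repo):

```python
import torch

m = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

m.train()
print(m(x))  # some entries zeroed, survivors scaled to 2.0

m.eval()
print(m(x))  # tensor([1., 1., 1., 1.]) -- identity in eval mode
```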

nemo_curator/stages/text/embedders/base.py

Lines changed: 13 additions & 23 deletions

@@ -15,8 +15,6 @@
 from dataclasses import dataclass
 from typing import Literal

-import cudf
-import cupy as cp
 import pandas as pd
 import torch
 import torch.nn.functional as F  # noqa: N812
@@ -29,8 +27,6 @@
 from nemo_curator.stages.text.models.utils import ATTENTION_MASK_COLUMN
 from nemo_curator.tasks import DocumentBatch

-from .utils import create_list_series_from_1d_or_2d_ar
-

 class EmbeddingModelStage(ModelStage):
     """HuggingFace model stage that produces embeddings with pooling."""
@@ -41,9 +37,10 @@ def __init__( # noqa: PLR0913
         embedding_field: str = "embeddings",
         pooling: Literal["mean_pooling", "last_token"] = "mean_pooling",
         hf_token: str | None = None,
-        model_inference_batch_size: int = 256,
+        model_inference_batch_size: int = 1024,
         has_seq_order: bool = True,
         padding_side: Literal["left", "right"] = "right",
+        autocast: bool = True,
     ):
         super().__init__(
             model_identifier=model_identifier,
@@ -52,6 +49,7 @@ def __init__( # noqa: PLR0913
             has_seq_order=has_seq_order,
             padding_side=padding_side,
             unpack_inference_batch=True,
+            autocast=autocast,
         )
         self.embedding_field = embedding_field
         self.pooling = pooling
@@ -62,33 +60,23 @@ def outputs(self) -> tuple[list[str], list[str]]:
     def setup(self, _: WorkerMetadata | None = None) -> None:
         """Load the model for inference."""
         self.model = AutoModel.from_pretrained(self.model_identifier, local_files_only=True)
-        self.model.eval()
-        self.model.to("cuda")
+        self.model.eval().to("cuda")

     def process_model_output(
         self, outputs: torch.Tensor, model_input_batch: dict[str, torch.Tensor] | None = None
     ) -> torch.Tensor:
         """Process model outputs to create embeddings."""
         if self.pooling == "mean_pooling":
-            return self._mean_pooling(outputs, model_input_batch[ATTENTION_MASK_COLUMN])
+            return self._mean_pooling(outputs, model_input_batch[ATTENTION_MASK_COLUMN]).cpu()
         else:
-            return self._get_last_token(outputs, model_input_batch[ATTENTION_MASK_COLUMN])
+            return self._get_last_token(outputs, model_input_batch[ATTENTION_MASK_COLUMN]).cpu()

-    def collect_outputs(self, processed_outputs: list[torch.Tensor]) -> cp.ndarray:
-        """Collect embeddings into a cupy array."""
-        # TODO : benchmarking this and maybe stay in cpu land
-        cupy_array_embeddings = [cp.asarray(emb_chunk) for emb_chunk in processed_outputs]
-        return cp.concatenate(cupy_array_embeddings, axis=0)
+    def collect_outputs(self, processed_outputs: list[torch.Tensor]) -> list[list[float]]:
+        return torch.cat(processed_outputs, dim=0).numpy().tolist()

-    def create_output_dataframe(self, df_cpu: pd.DataFrame, collected_output: cp.ndarray) -> pd.DataFrame:
+    def create_output_dataframe(self, df_cpu: pd.DataFrame, collected_output: list[list[float]]) -> pd.DataFrame:
         """Create output dataframe with embeddings."""
-        # TODO: Consider if it even makes sense to goto cudf or just concat in numpy
-        df_gpu = cudf.DataFrame(index=df_cpu.index)
-        df_gpu[self.embedding_field] = create_list_series_from_1d_or_2d_ar(collected_output, index=df_gpu.index)
-        # Add embedding_field back to cpu dataframe
-        df_cpu[self.embedding_field] = df_gpu[self.embedding_field].to_pandas()
-        del df_gpu
-        return df_cpu
+        return df_cpu.assign(**{self.embedding_field: collected_output})

     def _mean_pooling(self, model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
         token_embeddings = model_output[0]
@@ -119,7 +107,8 @@ class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
     max_seq_length: int | None = None
     padding_side: Literal["left", "right"] = "right"
     embedding_pooling: Literal["mean_pooling", "last_token"] = "mean_pooling"
-    model_inference_batch_size: int = 256
+    model_inference_batch_size: int = 1024
+    autocast: bool = True
     sort_by_length: bool = True
     hf_token: str | None = None

@@ -144,6 +133,7 @@ def __post_init__(self) -> None:
                 model_inference_batch_size=self.model_inference_batch_size,
                 has_seq_order=self.sort_by_length,
                 padding_side=self.padding_side,
+                autocast=self.autocast,
             ),
         ]
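
This file carries the substantive change behind the commit title: pooled embeddings move to CPU immediately, are collected with `torch.cat(...).numpy().tolist()`, and are attached via a plain `df_cpu.assign(...)`, removing the `cudf`/`cupy` round trip entirely (and the default inference batch size grows from 256 to 1024). A hedged sketch of the mean-pooling step itself, with toy tensors; the real stage pulls hidden states and the attention mask from the tokenized batch, and the final `F.normalize` is an assumption consistent with the `torch.nn.functional` import:

```python
import torch
import torch.nn.functional as F

token_embeddings = torch.randn(2, 5, 8)           # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # 0 marks padding

mask = attention_mask.unsqueeze(-1).float()
summed = (token_embeddings * mask).sum(dim=1)     # padding contributes nothing
counts = mask.sum(dim=1).clamp(min=1e-9)          # valid-token count per row
embeddings = F.normalize(summed / counts, p=2, dim=1)
print(embeddings.shape)                           # torch.Size([2, 8])
```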

nemo_curator/stages/text/models/model.py

Lines changed: 14 additions & 3 deletions

@@ -46,6 +46,8 @@ class ModelStage(ProcessingStage[DocumentBatch, DocumentBatch]):
             Sorting is encouraged to improve the performance of the inference model. Defaults to True.
         padding_side: The side to pad the input tokens. Defaults to "right".
         unpack_inference_batch: Whether to unpack the inference batch with **kwargs. Defaults to False.
+        autocast: Whether to use autocast. When True, we trade off minor accuracy for faster inference.
+            Defaults to True.

     """

@@ -58,6 +60,7 @@ def __init__( # noqa: PLR0913
         has_seq_order: bool = True,
         padding_side: Literal["left", "right"] = "right",
         unpack_inference_batch: bool = False,
+        autocast: bool = True,
     ):
         self._name = format_name_with_suffix(model_identifier, suffix="_model")
         # Assume that the model can fit on a single GPU
@@ -70,6 +73,7 @@
         self.has_seq_order = has_seq_order
         self.padding_side = padding_side
         self.unpack_inference_batch = unpack_inference_batch
+        self.autocast = autocast

     def inputs(self) -> tuple[list[str], list[str]]:
         return ["data"], [INPUT_ID_COLUMN, ATTENTION_MASK_COLUMN] + ([SEQ_ORDER_COLUMN] if self.has_seq_order else [])
@@ -147,17 +151,24 @@ def create_output_dataframe(self, df_cpu: pd.DataFrame, collected_output: dict[s
         msg = "Subclasses must implement this method"
         raise NotImplementedError(msg)

+    def _model_forward(self, model_input_batch: dict[str, torch.Tensor]) -> torch.Tensor:
+        if self.unpack_inference_batch:
+            return self.model(**model_input_batch)
+        else:
+            return self.model(model_input_batch)
+
     def process(self, batch: DocumentBatch) -> DocumentBatch:
         df_cpu = batch.to_pandas()

         processed_outputs = []
         for model_input_batch in self.yield_next_batch(df_cpu):
             # Forward pass
             with torch.no_grad():
-                if self.unpack_inference_batch:
-                    outputs = self.model(**model_input_batch)
+                if self.autocast:
+                    with torch.autocast(device_type="cuda"):
+                        outputs = self._model_forward(model_input_batch)
                 else:
-                    outputs = self.model(model_input_batch)
+                    outputs = self._model_forward(model_input_batch)

             processed_output = self.process_model_output(outputs, model_input_batch)
             del model_input_batch
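
This is the heart of the commit: `ModelStage.process` now owns the autocast policy, so every subclass gets it through one code path instead of re-implementing the same branch. A standalone sketch of the resulting loop, with assumed names and no NeMo Curator imports:

```python
import torch

def run_inference(
    model: torch.nn.Module, batches: list[torch.Tensor], autocast: bool = True
) -> list[torch.Tensor]:
    outputs = []
    for batch in batches:
        with torch.no_grad():
            if autocast:
                with torch.autocast(device_type="cuda"):
                    out = model(batch)
            else:
                out = model(batch)
        outputs.append(out.float().cpu())  # cast back outside the autocast region
    return outputs
```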

tests/stages/text/embedders/test_base.py

Lines changed: 4 additions & 2 deletions

@@ -292,15 +292,17 @@ def test_embedding_creator_stage_process_integration(self) -> None:
         assert embedding_stage.pooling == stage.embedding_pooling

     @pytest.mark.parametrize("pooling_strategy", ["mean_pooling", "last_token"])
+    @pytest.mark.parametrize("autocast", [True, False])
     @pytest.mark.gpu
     def test_embedding_creator_stage_with_reference_embeddings(
-        self, pooling_strategy: str, sample_data: DocumentBatch
+        self, pooling_strategy: str, sample_data: DocumentBatch, autocast: bool
     ) -> None:
         """Test embeddings match reference implementation (requires GPU and model download)."""
         stage = EmbeddingCreatorStage(
             model_identifier="sentence-transformers/all-MiniLM-L6-v2",
             embedding_pooling=pooling_strategy,
             model_inference_batch_size=32,
+            autocast=autocast,
         )

         # Decompose and setup stages
@@ -362,7 +364,7 @@ def _get_reference_embeddings(
         )
         inputs = {k: v.to("cuda") for k, v in inputs.items()}

-        with torch.no_grad(), torch.autocast(device_type="cuda"):
+        with torch.no_grad():
             outputs = model(**inputs)

         if pooling_strategy == "last_token":
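
The reference implementation drops its own autocast so the test can compare the stage's output (fp16 under autocast, fp32 without) against a clean fp32 baseline. When writing such comparisons, reduced precision calls for looser tolerances; a deterministic toy example of the magnitude of fp16 round-trip error (tolerance values illustrative, not taken from the test file):

```python
import torch

torch.manual_seed(0)
ref = torch.randn(4, 384)     # stand-in for fp32 reference embeddings
approx = ref.half().float()   # simulate precision lost to float16

# float16 keeps ~11 bits of mantissa, so relative error stays under ~5e-4.
torch.testing.assert_close(approx, ref, rtol=1e-3, atol=1e-3)
```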
