1515from dataclasses import dataclass
1616from typing import Literal
1717
18- import cudf
19- import cupy as cp
2018import pandas as pd
2119import torch
2220import torch .nn .functional as F # noqa: N812
2927from nemo_curator .stages .text .models .utils import ATTENTION_MASK_COLUMN
3028from nemo_curator .tasks import DocumentBatch
3129
32- from .utils import create_list_series_from_1d_or_2d_ar
33-
3430
3531class EmbeddingModelStage (ModelStage ):
3632 """HuggingFace model stage that produces embeddings with pooling."""
@@ -41,9 +37,10 @@ def __init__( # noqa: PLR0913
4137 embedding_field : str = "embeddings" ,
4238 pooling : Literal ["mean_pooling" , "last_token" ] = "mean_pooling" ,
4339 hf_token : str | None = None ,
44- model_inference_batch_size : int = 256 ,
40+ model_inference_batch_size : int = 1024 ,
4541 has_seq_order : bool = True ,
4642 padding_side : Literal ["left" , "right" ] = "right" ,
43+ autocast : bool = True ,
4744 ):
4845 super ().__init__ (
4946 model_identifier = model_identifier ,
@@ -52,6 +49,7 @@ def __init__( # noqa: PLR0913
5249 has_seq_order = has_seq_order ,
5350 padding_side = padding_side ,
5451 unpack_inference_batch = True ,
52+ autocast = autocast ,
5553 )
5654 self .embedding_field = embedding_field
5755 self .pooling = pooling
@@ -62,33 +60,23 @@ def outputs(self) -> tuple[list[str], list[str]]:
6260 def setup (self , _ : WorkerMetadata | None = None ) -> None :
6361 """Load the model for inference."""
6462 self .model = AutoModel .from_pretrained (self .model_identifier , local_files_only = True )
65- self .model .eval ()
66- self .model .to ("cuda" )
63+ self .model .eval ().to ("cuda" )
6764
6865 def process_model_output (
6966 self , outputs : torch .Tensor , model_input_batch : dict [str , torch .Tensor ] | None = None
7067 ) -> torch .Tensor :
7168 """Process model outputs to create embeddings."""
7269 if self .pooling == "mean_pooling" :
73- return self ._mean_pooling (outputs , model_input_batch [ATTENTION_MASK_COLUMN ])
70+ return self ._mean_pooling (outputs , model_input_batch [ATTENTION_MASK_COLUMN ]). cpu ()
7471 else :
75- return self ._get_last_token (outputs , model_input_batch [ATTENTION_MASK_COLUMN ])
72+ return self ._get_last_token (outputs , model_input_batch [ATTENTION_MASK_COLUMN ]). cpu ()
7673
77- def collect_outputs (self , processed_outputs : list [torch .Tensor ]) -> cp .ndarray :
78- """Collect embeddings into a cupy array."""
79- # TODO : benchmarking this and maybe stay in cpu land
80- cupy_array_embeddings = [cp .asarray (emb_chunk ) for emb_chunk in processed_outputs ]
81- return cp .concatenate (cupy_array_embeddings , axis = 0 )
74+ def collect_outputs (self , processed_outputs : list [torch .Tensor ]) -> list [list [float ]]:
75+ return torch .cat (processed_outputs , dim = 0 ).numpy ().tolist ()
8276
83- def create_output_dataframe (self , df_cpu : pd .DataFrame , collected_output : cp . ndarray ) -> pd .DataFrame :
77+ def create_output_dataframe (self , df_cpu : pd .DataFrame , collected_output : list [ list [ float ]] ) -> pd .DataFrame :
8478 """Create output dataframe with embeddings."""
85- # TODO: Consider if it even makes sense to goto cudf or just concat in numpy
86- df_gpu = cudf .DataFrame (index = df_cpu .index )
87- df_gpu [self .embedding_field ] = create_list_series_from_1d_or_2d_ar (collected_output , index = df_gpu .index )
88- # Add embedding_field back to cpu dataframe
89- df_cpu [self .embedding_field ] = df_gpu [self .embedding_field ].to_pandas ()
90- del df_gpu
91- return df_cpu
79+ return df_cpu .assign (** {self .embedding_field : collected_output })
9280
9381 def _mean_pooling (self , model_output : torch .Tensor , attention_mask : torch .Tensor ) -> torch .Tensor :
9482 token_embeddings = model_output [0 ]
@@ -119,7 +107,8 @@ class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
119107 max_seq_length : int | None = None
120108 padding_side : Literal ["left" , "right" ] = "right"
121109 embedding_pooling : Literal ["mean_pooling" , "last_token" ] = "mean_pooling"
122- model_inference_batch_size : int = 256
110+ model_inference_batch_size : int = 1024
111+ autocast : bool = True
123112 sort_by_length : bool = True
124113 hf_token : str | None = None
125114
@@ -144,6 +133,7 @@ def __post_init__(self) -> None:
144133 model_inference_batch_size = self .model_inference_batch_size ,
145134 has_seq_order = self .sort_by_length ,
146135 padding_side = self .padding_side ,
136+ autocast = self .autocast ,
147137 ),
148138 ]
149139
0 commit comments