@@ -267,25 +267,25 @@ def forward(
 class STEP1TextEncoder(torch.nn.Module):
 
     def __init__(self, model_dir, max_length=320):
-        super(STEP1TextEncoder, self).__init__()
+        super().__init__()
         self.max_length = max_length
         self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
         text_encoder = Step1Model.from_pretrained(model_dir)
         self.text_encoder = text_encoder.eval().to(torch.bfloat16)
 
     @torch.no_grad
+    @torch.autocast(device_type='cuda', dtype=torch.bfloat16)
     def forward(self, prompts, with_mask=True, max_length=None):
         self.device = next(self.text_encoder.parameters()).device
-        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-            if type(prompts) is str:
-                prompts = [prompts]
-
-            txt_tokens = self.text_tokenizer(prompts,
-                                             max_length=max_length or self.max_length,
-                                             padding="max_length",
-                                             truncation=True,
-                                             return_tensors="pt")
-            y = self.text_encoder(txt_tokens.input_ids.to(self.device),
+        if type(prompts) is str:
+            prompts = [prompts]
+
+        txt_tokens = self.text_tokenizer(prompts,
+                                         max_length=max_length or self.max_length,
+                                         padding="max_length",
+                                         truncation=True,
+                                         return_tensors="pt")
+        y = self.text_encoder(txt_tokens.input_ids.to(self.device),
                                   attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None)
-            y_mask = txt_tokens.attention_mask
+        y_mask = txt_tokens.attention_mask
         return y.transpose(0, 1), y_mask