@@ -73,22 +73,22 @@ def __init__(
         world_size = torch.distributed.get_world_size()
         model_name = self.cfg["model_name"]
         if self.cfg["precision"] == "float32":
-            dtype = torch.float32
+            self.dtype = torch.float32
         elif self.cfg["precision"] == "bfloat16":
-            dtype = torch.bfloat16
+            self.dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision: {self.cfg['precision']}")

         print(f"[Rank {rank}] Loading model {model_name} on CPU...")
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu",  # load weights onto CPU initially
-            torch_dtype=dtype,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=torch.float32,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )
         self.reference_model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu",  # load weights onto CPU initially
-            torch_dtype=dtype,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=torch.float32,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )

         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
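
In isolation, the precision handling above boils down to: resolve the configured precision string once, load the weights in full float32 for now (per the linked issue), and keep the resolved dtype around for autocast later. A minimal sketch of that pattern, using a placeholder checkpoint name rather than anything from this repo:

```python
import torch
from transformers import AutoModelForCausalLM

def resolve_dtype(precision: str) -> torch.dtype:
    # Translate the config string into a torch dtype, rejecting typos early.
    if precision == "float32":
        return torch.float32
    if precision == "bfloat16":
        return torch.bfloat16
    raise ValueError(f"Unknown precision: {precision}")

# The resolved dtype is only recorded for later autocast use; the weights
# themselves are loaded in float32 regardless of the configured precision.
autocast_dtype = resolve_dtype("bfloat16")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",                   # placeholder model name, not from the diff
    device_map="cpu",
    torch_dtype=torch.float32,
)
```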
@@ -272,16 +272,17 @@ def train(
             # For right-padded sequence, set 1s at the beginning of the sequence
             attention_mask[i, :length] = 1

-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            use_cache=False,
-        )
-        # Get logprobs
-        if not hasattr(outputs, "logits"):
-            logits = self.model.lm_head(outputs.last_hidden_state)
-        else:
-            logits = outputs.logits
+        with torch.autocast(device_type="cuda", dtype=self.dtype):
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                use_cache=False,
+            )
+            # Get logprobs
+            if not hasattr(outputs, "logits"):
+                logits = self.model.lm_head(outputs.last_hidden_state)
+            else:
+                logits = outputs.logits

         loss, loss_metrics = loss_fn(logits, mb)

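The training-side change is the usual autocast recipe: keep float32 weights and run the forward pass under `torch.autocast` in the stored dtype. A rough, self-contained sketch of the same pattern (hypothetical toy module and batch, CPU autocast so it runs anywhere):

```python
import torch

# Toy stand-ins for self.model and the right-padded microbatch above.
model = torch.nn.Linear(16, 16)   # parameters remain float32
batch = torch.randn(4, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Inside the block, matmul-heavy ops run in bfloat16; the float32
    # parameters are cast on the fly instead of being stored downcast.
    out = model(batch)

print(out.dtype, model.weight.dtype)  # bfloat16 activations, float32 weights
```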
@@ -358,11 +359,12 @@ def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict:
             attention_mask[i, :length] = 1

         # Process with the model directly using right-padded inputs
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            use_cache=False,
-        )
+        with torch.autocast(device_type="cuda", dtype=self.dtype):
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                use_cache=False,
+            )
         log_probs = torch.nn.functional.log_softmax(
             outputs.logits.to(torch.float32), dim=-1
         )
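
One detail worth noting: even with the forward pass under autocast, the logits are upcast to float32 before `log_softmax`, so the normalization over the vocabulary is done in full precision. Roughly:

```python
import torch

# Hypothetical bfloat16 logits, shaped like the autocast output above.
logits = torch.randn(2, 8, 128, dtype=torch.bfloat16)

# Upcasting before log_softmax keeps the vocab-wide normalization in float32.
log_probs = torch.nn.functional.log_softmax(logits.to(torch.float32), dim=-1)
```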