
Commit 6b3dc31

feat: Enable amp with autocast (fix poor bf16 convergence on GRPO) (#26)
Signed-off-by: Sahil Jain <sahilj@nvidia.com>
1 parent 04f4d16 commit 6b3dc31

File tree

1 file changed: +21 −19 lines


nemo_reinforcer/models/policy/hf_policy.py

Lines changed: 21 additions & 19 deletions
@@ -73,22 +73,22 @@ def __init__(
         world_size = torch.distributed.get_world_size()
         model_name = self.cfg["model_name"]
         if self.cfg["precision"] == "float32":
-            dtype = torch.float32
+            self.dtype = torch.float32
         elif self.cfg["precision"] == "bfloat16":
-            dtype = torch.bfloat16
+            self.dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision: {self.cfg['precision']}")

         print(f"[Rank {rank}] Loading model {model_name} on CPU...")
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu",  # load weights onto CPU initially
-            torch_dtype=dtype,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=torch.float32,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )
         self.reference_model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="cpu",  # load weights onto CPU initially
-            torch_dtype=dtype,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+            torch_dtype=torch.float32,  # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
         )

         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -272,16 +272,17 @@ def train(
                     # For right-padded sequence, set 1s at the beginning of the sequence
                     attention_mask[i, :length] = 1

-                outputs = self.model(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    use_cache=False,
-                )
-                # Get logprobs
-                if not hasattr(outputs, "logits"):
-                    logits = self.model.lm_head(outputs.last_hidden_state)
-                else:
-                    logits = outputs.logits
+                with torch.autocast(device_type="cuda", dtype=self.dtype):
+                    outputs = self.model(
+                        input_ids=input_ids,
+                        attention_mask=attention_mask,
+                        use_cache=False,
+                    )
+                    # Get logprobs
+                    if not hasattr(outputs, "logits"):
+                        logits = self.model.lm_head(outputs.last_hidden_state)
+                    else:
+                        logits = outputs.logits

                 loss, loss_metrics = loss_fn(logits, mb)

@@ -358,11 +359,12 @@ def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict:
             attention_mask[i, :length] = 1

         # Process with the model directly using right-padded inputs
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            use_cache=False,
-        )
+        with torch.autocast(device_type="cuda", dtype=self.dtype):
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                use_cache=False,
+            )
         log_probs = torch.nn.functional.log_softmax(
             outputs.logits.to(torch.float32), dim=-1
         )
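
For context: this change adopts the standard PyTorch automatic-mixed-precision recipe. Master weights are now loaded in float32 (torch_dtype=torch.float32) and only the forward pass runs under torch.autocast with the configured dtype, so matmuls execute in bfloat16 while weight updates stay in full precision; doing the whole update in bfloat16 rounds away small gradient contributions, a common cause of the poor convergence the commit title refers to. The sketch below illustrates the same pattern in isolation; it is not code from this repository, and the model name, sample input, and CUDA device are illustrative assumptions.

# Minimal sketch of the AMP recipe this commit switches to (illustrative only:
# "gpt2" and the sample input are stand-ins, not part of the commit).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # hypothetical small model for illustration
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # master weights stay in full precision
).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)

batch = tokenizer(["The quick brown fox"], return_tensors="pt").to("cuda")

# The forward pass runs its matmuls in bfloat16 under autocast, while the
# parameters (and any optimizer state) remain float32.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(**batch, use_cache=False)
    logits = outputs.logits

# Cast back to float32 before numerically sensitive ops such as log_softmax,
# as get_logprobs does in the diff above.
log_probs = torch.nn.functional.log_softmax(logits.to(torch.float32), dim=-1)

Note that get_logprobs still casts the logits back to float32 before log_softmax, keeping the numerically sensitive reduction in full precision.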
