Commit 066d525
[NPU]: add support for grpo loss (#1049)
## Summary

To facilitate CI integration, we tested the chunked losses on NPU. Because of differences in the NPU devices, torch compilation had to be disabled to pass most of the tests. However, some test cases for the GRPO loss operator still failed. The root cause is that the parent class LigerFusedLinearPPOBase does not convert the logits-related data to float32 during the computation, while the NPU produces numerical errors when computing in bf16. The test reference is therefore adjusted here to keep the logits in their original dtype, as a first step toward supporting the CI integration. A minimal sketch of the underlying dtype sensitivity follows the checklist below.

## Testing Done

![Test results](https://github.com/user-attachments/assets/aa500e5d-c375-438d-b29a-135b13bb53b6)

- Hardware Type: Atlas 800I A2
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
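As a rough illustration of the root cause described in the summary (not part of this commit), the following standalone sketch compares log-probabilities computed in bf16 against a float32 reference. The two paths are close but not bit-identical, so a float32-only test reference can fail against a bf16 kernel on backends with different accumulation behavior, such as the NPU. The tensor shapes and seed are arbitrary; it assumes a PyTorch build where `log_softmax` supports bfloat16 on the current device.

```python
# Rough illustration (not from the repo): bf16 vs. float32 log-softmax results
# differ slightly, which is why a float32 reference with tight tolerances can
# fail against a bf16 computation path. Shapes and seed are arbitrary.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 8, 128, dtype=torch.bfloat16)

logps_bf16 = F.log_softmax(logits, dim=-1)           # bf16 path (what the kernel sees)
logps_fp32 = F.log_softmax(logits.float(), dim=-1)   # float32 path (the old test reference)

# Maximum absolute difference between the two paths
print((logps_bf16.float() - logps_fp32).abs().max())
```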
1 parent 60f6c84 commit 066d525

File tree: 1 file changed (+2 additions, −2 deletions)

test/chunked_loss/test_grpo_loss.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -170,7 +170,7 @@ def forward(
     ):
         logits = x @ self.lin.weight.t()
         if self.lin.bias is not None:
-            logits = logits + self.lin.bias.float()
+            logits = logits + self.lin.bias
         if self.temperature != 1.0:
             logits = logits / self.temperature
         # Get log probabilities
@@ -414,7 +414,7 @@ def test_correctness(
     if torch_lm_head_grpo.lin.bias is not None:
         logits = logits + torch_lm_head_grpo.lin.bias
     logits = logits / temperature
-    logps = F.log_softmax(logits.float(), dim=-1)
+    logps = F.log_softmax(logits, dim=-1)
     per_token_logps = logps.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)

     # Create attention mask with random padding [B, T]
```
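For reference, here is a minimal sketch (not taken from `test_grpo_loss.py`) of the reference per-token log-prob path as it looks after this change: no float32 upcast of the bias or the softmax input, so it matches what LigerFusedLinearPPOBase computes in the parameters' dtype. The helper name and the shapes in the comments are illustrative.

```python
# Illustrative helper (not from the test file): reference per-token log-probs
# computed entirely in the parameters' dtype (e.g. bfloat16), mirroring the
# test reference after this change, which no longer upcasts to float32.
import torch
import torch.nn.functional as F


def reference_per_token_logps(
    x: torch.Tensor,                    # [B, T, H] hidden states
    lin: torch.nn.Linear,               # lm_head projection H -> V, bias optional
    selected_token_ids: torch.Tensor,   # [B, T] sampled token ids
    temperature: float = 1.0,
) -> torch.Tensor:
    logits = x @ lin.weight.t()
    if lin.bias is not None:
        logits = logits + lin.bias              # no .float() upcast
    if temperature != 1.0:
        logits = logits / temperature
    logps = F.log_softmax(logits, dim=-1)       # stays in the input dtype
    # Log-prob of each selected token: [B, T]
    return logps.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
```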
