Commit 066d525
[NPU]: add support for grpo loss (#1049)
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
To facilitate CI integration, we tested the chunked loss on NPU. Because
of differences in NPU devices, torch compilation had to be disabled for
most of the tests to pass. Even so, some test cases for the GRPO loss
operator still failed. The root cause was that the parent class
`LigerFusedLinearPPOBase` did not convert the logits-related data to
float32 before computing, and the NPU produces errors when computing on
bf16 data. We therefore modified that path here as a first step toward
CI integration.
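
As a rough illustration of the upcast described above, the sketch below converts bf16 logits to float32 before taking the log-softmax. It is not the code from this commit: `selected_log_probs`, the tensor shapes, and the random inputs are all hypothetical, chosen only to show the pattern.

```python
import torch


def selected_log_probs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: log-probabilities of the chosen tokens, computed in float32."""
    # Upcast first so the softmax and log run in float32 even when the model emits bf16.
    logits_fp32 = logits.float()
    log_probs = torch.log_softmax(logits_fp32, dim=-1)
    # Pick the log-probability of each selected token along the vocabulary axis.
    return log_probs.gather(dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)


torch.manual_seed(0)
logits_bf16 = torch.randn(2, 4, 8, dtype=torch.bfloat16)  # assumed (batch, seq, vocab)
token_ids = torch.randint(0, 8, (2, 4))                   # assumed selected token ids
logps = selected_log_probs(logits_bf16, token_ids)
print(logps.dtype, logps.shape)
```

Casting once at the entry point keeps every downstream reduction in float32, at the cost of a temporary float32 copy of the logits chunk.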
## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
<img width="1700" height="532" alt="image"
src="https://github.com/user-attachments/assets/aa500e5d-c375-438d-b29a-135b13bb53b6"
/>
- Hardware Type: Atlas 800I A2
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

1 parent 60f6c84 commit 066d525
1 file changed, +2 -2 lines changed (lines 173 and 417 replaced)