Adding support for apo losses, sppo_hard and nca_pair (#841)
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
This PR adds support for the `apo_zero`, `apo_down`, `sppo_hard`, and `nca_pair`
loss types, mirroring the implementations in TRL's DPO trainer
(https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py).
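
For reviewers unfamiliar with these loss types, the following is a minimal per-example sketch of the four formulas as they are understood to be defined in TRL's `dpo_trainer.py`; please verify against that file. The function name `preference_loss`, the scalar (non-batched) form, and the default `beta` are illustrative assumptions and are not the chunked implementation added in this PR. Each log-ratio is the policy log-probability minus the reference log-probability for that response.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def logsigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def preference_loss(
    chosen_logratio: float,
    rejected_logratio: float,
    beta: float = 0.1,
    loss_type: str = "apo_zero",
) -> float:
    """Hypothetical scalar helper; formulas assumed to match TRL."""
    if loss_type == "apo_zero":
        # Push the chosen likelihood up and the rejected likelihood down.
        loss_chosen = 1.0 - sigmoid(beta * chosen_logratio)
        loss_rejected = sigmoid(beta * rejected_logratio)
        return loss_chosen + loss_rejected
    if loss_type == "apo_down":
        # Push the chosen likelihood down, and the rejected one down further.
        loss_chosen = sigmoid(beta * chosen_logratio)
        loss_rejected = 1.0 - sigmoid(beta * (chosen_logratio - rejected_logratio))
        return loss_chosen + loss_rejected
    if loss_type == "sppo_hard":
        # Squared distance of each log-ratio from its hard target of +/- 1/(2*beta).
        target = 0.5 / beta
        return (chosen_logratio - target) ** 2 + (rejected_logratio + target) ** 2
    if loss_type == "nca_pair":
        chosen_reward = beta * chosen_logratio
        rejected_reward = beta * rejected_logratio
        return (
            -logsigmoid(chosen_reward)
            - 0.5 * logsigmoid(-chosen_reward)
            - 0.5 * logsigmoid(-rejected_reward)
        )
    raise ValueError(f"Unsupported loss_type: {loss_type}")
```

When the policy equals the reference (both log-ratios are zero), `apo_zero` and `apo_down` both evaluate to 1.0, which makes a convenient sanity check when comparing against another implementation.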
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->
## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
`python -m pytest test/chunked_loss/test_dpo_loss.py`
<!--
Replace BLANK with your device type. For example, A100-80G-PCIe
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->
- Hardware Type: H100
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
---------
Co-authored-by: Manan Shah <[email protected]>
Co-authored-by: Vaibhav Jindal <[email protected]>