Per-token KL penalty from the SFT model during PPO training #2608

@MXuer

  • I can't find where the "per-token KL penalty from the SFT model" is applied during PPO training in model/model_training/trainer_rl.py; maybe I missed something. Could you tell me how the two losses (the PPO objective and the KL penalty) are combined?
  • I found the loss function "PolyLoss" in model/model_training/losses.py. Is this the loss function for the "per-token KL penalty from the SFT model" part? If so, why is a cross-entropy (CE) term combined with it?
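
To make the question concrete, here is a minimal sketch of what I understand by a per-token KL penalty in RLHF-style PPO: each generated token is penalized by the log-probability gap between the policy and the frozen SFT model, and the reward model's scalar score is added on the last token. This is only my assumption about the usual formulation, not code taken from trainer_rl.py; the function name `kl_shaped_rewards` and the parameter `kl_coef` are hypothetical.

```python
from typing import List


def kl_shaped_rewards(
    policy_logprobs: List[float],
    ref_logprobs: List[float],
    score: float,
    kl_coef: float = 0.1,  # hypothetical default, not a value from the repository
) -> List[float]:
    """Build per-token rewards for one generated sequence.

    policy_logprobs[t] is log pi(a_t | s_t) under the policy being trained,
    ref_logprobs[t] is the same quantity under the frozen SFT model.
    """
    if len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must have the same length")
    if not policy_logprobs:
        raise ValueError("sequence must contain at least one token")

    # Per-token KL estimate: log pi(a_t|s_t) - log pi_ref(a_t|s_t),
    # scaled by kl_coef and subtracted from the reward at every token.
    rewards = [
        -kl_coef * (lp - lr) for lp, lr in zip(policy_logprobs, ref_logprobs)
    ]

    # The reward model scores the whole response once, so its scalar
    # score is added only at the final token.
    rewards[-1] += score
    return rewards


if __name__ == "__main__":
    print(kl_shaped_rewards([-1.0, -2.0], [-1.5, -2.0], score=1.0))
```

In this formulation the penalty would enter through the rewards passed to PPO rather than as a separate loss term, which may be why I cannot find an explicit KL loss in the trainer.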

Thanks a lot.
