 from forge.api.types import (
     ForwardBackwardResult,
     ForwardResult,
+    LossFn,
     OptimStepResult,
     TextTrainBatch,
     TrainerInfo,
 class Trainer(Protocol):
     """Protocol defining the standard interface for all Forge trainers."""
 
-    async def forward_backward(self, batch: TextTrainBatch) -> ForwardBackwardResult:
+    async def forward_backward(
+        self, batch: TextTrainBatch, loss_fn: LossFn | None = None
+    ) -> ForwardBackwardResult:
         """Execute forward pass and backward pass for one batch of data.
 
         Basic usage - single batch per optimizer step:
@@ -45,16 +48,26 @@ async def forward_backward(self, batch: TextTrainBatch) -> ForwardBackwardResult:
         >>> await trainer.forward_backward(batch2)  # Accumulates another batch
         >>> await trainer.optim_step()  # Apply all accumulated gradients
 
+        Custom loss function for specific batches:
+        >>> def custom_loss(logits: torch.Tensor, batch: TextTrainBatch) -> torch.Tensor:
+        >>>     # Custom loss computation (e.g., PPO clip, DPO, etc.)
+        >>>     return loss
+        >>>
+        >>> result = await trainer.forward_backward(batch, loss_fn=custom_loss)
+
         Args:
             batch: TextTrainBatch containing input_ids, target_ids, and optional
                 target_mask/target_weights. See forge.api.types.TextTrainBatch for details.
+            loss_fn: Optional custom loss function. If None, uses the loss function
+                configured at trainer creation. Signature: (logits, batch) -> loss.
+                Useful for mixed training objectives or experimentation.
 
         Returns:
             ForwardBackwardResult containing loss and metrics
 
         Note:
-            The loss function is configured at trainer creation time via the
-            `loss` parameter, not passed to this method.
+            The default loss function is configured at trainer creation time via the
+            `loss` parameter. The `loss_fn` parameter here allows a per-batch override.
         """
         ...
 
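For concreteness, here is a minimal sketch of a callable compatible with the documented (logits, batch) -> loss signature. It is an illustration under stated assumptions, not part of the Forge API: masked_ce_loss is a hypothetical name, the logits shape [batch, seq_len, vocab_size] is assumed, and only the TextTrainBatch fields named in the Args section (target_ids, target_mask) are used.

import torch
import torch.nn.functional as F

from forge.api.types import TextTrainBatch


def masked_ce_loss(logits: torch.Tensor, batch: TextTrainBatch) -> torch.Tensor:
    # Hypothetical loss_fn sketch; assumes logits is [batch, seq_len, vocab_size]
    # and is already aligned position-for-position with batch.target_ids.
    vocab_size = logits.size(-1)
    # Per-token cross-entropy over the flattened sequence.
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        batch.target_ids.reshape(-1),
        reduction="none",
    )
    if batch.target_mask is not None:
        # Average only over positions the optional mask marks as real targets.
        mask = batch.target_mask.reshape(-1).to(per_token.dtype)
        return (per_token * mask).sum() / mask.sum().clamp(min=1.0)
    return per_token.mean()


# Usage (inside an async context):
# result = await trainer.forward_backward(batch, loss_fn=masked_ce_loss)

Because loss_fn is resolved per call, one trainer can mix objectives across batches (e.g., alternating a supervised loss with PPO-clip or DPO style losses) and falls back to the loss configured at creation whenever loss_fn is None.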