Commit 4d11e4e

Author: Allen Wang
Message: add custom loss
Parent: 83f3a8f

File tree: 3 files changed, 23 additions and 4 deletions


src/forge/api/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@
 from forge.api.types import (
     ForwardBackwardResult,
     ForwardResult,
+    LossFn,
     OptimStepResult,
     TextTrainBatch,
     TrainerInfo,
@@ -27,4 +28,5 @@
     "ForwardResult",
     "TrainerInfo",
     "TrainerStatus",
+    "LossFn",
 ]
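With this re-export, downstream code can pull the loss type from the package root instead of reaching into forge.api.types. A one-line sketch of the intended import path (hypothetical caller code, not part of this commit):

# After this commit, LossFn is part of the package's public API surface.
from forge.api import LossFn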

src/forge/api/trainer.py

Lines changed: 16 additions & 3 deletions

@@ -18,6 +18,7 @@
 from forge.api.types import (
     ForwardBackwardResult,
     ForwardResult,
+    LossFn,
     OptimStepResult,
     TextTrainBatch,
     TrainerInfo,
@@ -29,7 +30,9 @@
 class Trainer(Protocol):
     """Protocol defining the standard interface for all Forge trainers."""

-    async def forward_backward(self, batch: TextTrainBatch) -> ForwardBackwardResult:
+    async def forward_backward(
+        self, batch: TextTrainBatch, loss_fn: LossFn | None = None
+    ) -> ForwardBackwardResult:
         """Execute forward pass and backward pass for one batch of data.

         Basic usage - single batch per optimizer step:
@@ -45,16 +48,26 @@ async def forward_backward(self, batch: TextTrainBatch) -> ForwardBackwardResult
         >>> await trainer.forward_backward(batch2)  # Accumulates another batch
         >>> await trainer.optim_step()  # Apply all accumulated gradients

+        Custom loss function for specific batches:
+        >>> def custom_loss(logits: torch.Tensor, batch: TextTrainBatch) -> torch.Tensor:
+        >>>     # Custom loss computation (e.g., PPO clip, DPO, etc.)
+        >>>     return loss
+        >>>
+        >>> result = await trainer.forward_backward(batch, loss_fn=custom_loss)
+
         Args:
             batch: TextTrainBatch containing input_ids, target_ids, and optional
                 target_mask/target_weights. See forge.api.types.TextTrainBatch for details.
+            loss_fn: Optional custom loss function. If None, uses the loss function
+                configured at trainer creation. Signature: (logits, batch) -> loss.
+                Useful for mixed training objectives or experimentation.

         Returns:
             ForwardBackwardResult containing loss and metrics

         Note:
-            The loss function is configured at trainer creation time via the
-            `loss` parameter, not passed to this method.
+            The default loss function is configured at trainer creation time via the
+            `loss` parameter. The `loss_fn` parameter here allows per-batch override.
         """
         ...
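As a concrete illustration of the call pattern documented above, here is a hedged sketch of mixing the trainer's default loss with a per-batch override inside one gradient-accumulation window. The names trainer, main_batch, aux_batch, and aux_loss are hypothetical stand-ins; only the forward_backward/optim_step calls and the (logits, batch) -> loss signature come from the interface shown in the diff.

import torch
import torch.nn.functional as F

from forge.api import TextTrainBatch


def aux_loss(logits: torch.Tensor, batch: TextTrainBatch) -> torch.Tensor:
    # Toy objective matching the documented (logits, batch) -> loss signature:
    # plain cross-entropy over every position, ignoring masks and weights.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), batch.target_ids.reshape(-1)
    )


async def mixed_step(trainer, main_batch: TextTrainBatch, aux_batch: TextTrainBatch):
    await trainer.forward_backward(main_batch)                   # default loss from trainer config
    await trainer.forward_backward(aux_batch, loss_fn=aux_loss)  # per-batch override
    return await trainer.optim_step()                            # apply accumulated gradients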

src/forge/api/types.py

Lines changed: 5 additions & 1 deletion

@@ -7,11 +7,15 @@
 """Type definitions for the Forge API."""

 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Callable, TypeAlias

 import torch


+# Loss function signature: takes logits and batch, returns scalar loss
+LossFn: TypeAlias = Callable[[torch.Tensor, "TextTrainBatch"], torch.Tensor]
+
+
 @dataclass
 class TextTrainBatch:
     """A batch of text training data for forward_backward.
