Skip to content

Commit 723c7cb

Browse files
Mark ObozovMark Obozov
authored andcommitted
fix typehint
1 parent b50cb8a commit 723c7cb

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/forge/losses/grpo_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class SimpleGRPOLoss(nn.Module):
1616
https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
1717
"""
1818

19-
def __init__(self, beta: float = 0.1) -> torch.Tensor | LossMetrics:
19+
def __init__(self, beta: float = 0.1) -> tuple[torch.Tensor, LossMetrics]:
2020
super().__init__()
2121
self.beta = beta
2222

src/forge/losses/reinforce_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self):
3030

3131
def forward(
3232
self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
33-
) -> torch.Tensor | LossMetrics:
33+
) -> tuple[torch.Tensor, LossMetrics]:
3434
trainer_log_probs = selective_log_softmax(trainer_logits, target_ids)
3535
target_mask = target_mask.detach()
3636
target_weights = target_weights

0 commit comments

Comments
 (0)