Commit ad6852a

fix CI bugs
1 parent 510e291 commit ad6852a

4 files changed, 10 insertions(+), 3 deletions(-)


nemo_rl/algorithms/loss_functions.py

Lines changed: 2 additions & 2 deletions
@@ -46,9 +46,9 @@ class ClippedPGLossConfig(TypedDict):
     use_importance_sampling_correction: bool
     truncated_importance_sampling_ratio: float | None
     # Type of truncated importance sampling: "tis" (clamp max) or "icepop" (filter [min, max])
-    truncated_importance_sampling_type: NotRequired[str]
+    truncated_importance_sampling_type: NotRequired[str | None]
     # Lower bound for ICE-POP filtering (default 0.5)
-    truncated_importance_sampling_ratio_min: NotRequired[float]
+    truncated_importance_sampling_ratio_min: NotRequired[float | None]
     token_level_loss: bool
     # If True, apply the off-policy importance-sampling correction at the
     # sequence level (one weight per generated sample), as in GSPO.
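
Note: both fields are widened to also accept an explicit None, presumably so configs that set these keys to null still type-check. For context, here is a minimal sketch of the two correction modes named in the comment; the function name, signature, and defaults are illustrative assumptions, not this repo's implementation:

import torch

def truncated_is_correction(
    ratios: torch.Tensor,            # per-token importance ratios pi_new / pi_old
    ts_type: str | None = "tis",     # "tis" (clamp max) or "icepop" (filter [min, max])
    ratio_max: float = 2.0,
    ratio_min: float | None = 0.5,   # lower bound used only by "icepop"
) -> torch.Tensor:
    if ts_type == "tis":
        # Truncated importance sampling: clamp only the upper tail.
        return ratios.clamp(max=ratio_max)
    if ts_type == "icepop":
        # ICE-POP: keep ratios inside [ratio_min, ratio_max] and zero out
        # the rest, so those tokens drop from the loss.
        keep = (ratios >= ratio_min) & (ratios <= ratio_max)
        return ratios * keep
    return ratios  # no correction when type is None

print(truncated_is_correction(torch.tensor([0.2, 1.0, 3.0]), "icepop"))
# tensor([0., 1., 0.])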

nemo_rl/algorithms/reward_functions.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ class RewardShapingConfig(TypedDict):
     # Stop properly penalty: scale factor for rewards of truncated responses (0-1).
     # When set to 0, truncated responses get zero reward.
     # When set to 1, no penalty is applied (default behavior).
-    stop_properly_penalty_coef: NotRequired[float]
+    stop_properly_penalty_coef: NotRequired[float | None]
 
 
 def apply_reward_shaping(
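
The same None-widening applies here. Per the comments above, the coefficient scales rewards of truncated responses: 0 zeroes them out, 1 leaves them untouched. A minimal sketch of that behavior, with an illustrative signature rather than the actual apply_reward_shaping:

import torch

def shape_rewards(
    rewards: torch.Tensor,        # one reward per sample
    truncated: torch.Tensor,      # bool mask of truncated samples
    stop_properly_penalty_coef: float | None = None,
) -> torch.Tensor:
    if stop_properly_penalty_coef is None:
        return rewards            # penalty disabled
    # coef == 0 -> truncated samples get zero reward;
    # coef == 1 -> no penalty (default behavior).
    return torch.where(truncated, rewards * stop_properly_penalty_coef, rewards)

rewards = torch.tensor([1.0, 0.5, 0.8])
truncated = torch.tensor([False, True, False])
print(shape_rewards(rewards, truncated, 0.0))  # tensor([1.0000, 0.0000, 0.8000])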

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 1 addition & 0 deletions
@@ -543,6 +543,7 @@ def generate(
                 "logprobs": torch.zeros((0, 0), dtype=torch.float),
                 "generation_lengths": torch.zeros(0, dtype=torch.long),
                 "unpadded_sequence_lengths": torch.zeros(0, dtype=torch.long),
+                "truncated": torch.zeros(0, dtype=torch.bool),
             }
         )
 
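
The empty-batch early return in generate now carries the same "truncated" key as real generation output. A hedged illustration of why the placeholder schema has to match downstream expectations (concat_batches is a stand-in, not a helper from this repo):

import torch

def concat_batches(
    a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    # Raises KeyError if one batch lacks a key the other has; this is the
    # failure mode an incomplete empty-batch placeholder triggers in CI.
    return {k: torch.cat([a[k], b[k]]) for k in a}

empty = {"truncated": torch.zeros(0, dtype=torch.bool)}
real = {"truncated": torch.tensor([True, False])}
print(concat_batches(empty, real)["truncated"])  # tensor([ True, False])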

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 6 additions & 0 deletions
@@ -727,12 +727,18 @@ async def process_single_sample(sample_idx):
             device=input_ids_single_row.device,
         )
 
+        # Not truncated since no generation was attempted (length constraint)
+        truncated_tensor = torch.tensor(
+            [False], dtype=torch.bool, device=input_ids_single_row.device
+        )
+
         result_batch = BatchedDataDict[GenerationOutputSpec](
             {
                 "output_ids": output_ids_single_item_batched,
                 "logprobs": logprobs_single_item,
                 "generation_lengths": generation_lengths_tensor,
                 "unpadded_sequence_lengths": unpadded_sequence_lengths_tensor,
+                "truncated": truncated_tensor,
             }
         )
 
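
This hunk covers only the early-return path where no tokens were generated, so the flag is hard-coded to False. For completed generations the flag has to come from elsewhere; one plausible source is vLLM's per-output finish reason (vLLM reports finish_reason == "length" when a completion hits its token budget). The helper below is an assumption for illustration, not this worker's actual logic:

import torch

def truncation_flag(finish_reason: str | None, device: torch.device) -> torch.Tensor:
    # "length" means generation stopped by hitting max_tokens, i.e. the
    # response was truncated rather than ending at a stop condition.
    return torch.tensor([finish_reason == "length"], dtype=torch.bool, device=device)

print(truncation_flag("length", torch.device("cpu")))  # tensor([True])
print(truncation_flag("stop", torch.device("cpu")))    # tensor([False])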
