
Commit 2bcfff7 (1 parent: 088fb20)

Update

[ghstack-poisoned]

2 files changed: +80 −82 lines

test/llm/test_llm_envs.py
Lines changed: 59 additions & 0 deletions

@@ -420,6 +420,65 @@ def test_ifeval(self):
         # env.check_env_specs()


+@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
+class TestIFEvalRewardAggregator:
+    """Unit tests for the simplified IFEval reward aggregator."""
+
+    def test_perfect_score_with_format(self):
+        from torchrl.envs.llm.reward.ifeval._scorer import IFEvalScoreData, IfEvalScorer
+
+        scorer = IfEvalScorer()
+        score = IFEvalScoreData(
+            prompt_level_strict_acc=torch.tensor([True]),
+            inst_level_strict_acc=torch.tensor([True]),
+            prompt_level_loose_acc=torch.tensor([True]),
+            inst_level_loose_acc=torch.tensor([True]),
+            batch_size=(),
+        )
+        reward = scorer.default_reward_aggregator(
+            score,
+            think_blocks=["reasoning"],
+            answer_blocks=["answer"],
+        )
+        # format_score = 1.0, format_bonus = 0.1 + 0.05 = 0.15, so reward = 1.15
+        assert reward.item() == pytest.approx(1.15, abs=0.01)
+
+    def test_zero_score_no_answer(self):
+        from torchrl.envs.llm.reward.ifeval._scorer import IFEvalScoreData, IfEvalScorer
+
+        scorer = IfEvalScorer()
+        score = IFEvalScoreData(
+            prompt_level_strict_acc=torch.tensor([False]),
+            inst_level_strict_acc=torch.tensor([False]),
+            prompt_level_loose_acc=torch.tensor([False]),
+            inst_level_loose_acc=torch.tensor([False]),
+            batch_size=(),
+        )
+        reward = scorer.default_reward_aggregator(
+            score, think_blocks=[], answer_blocks=[]
+        )
+        # No format bonus, all metrics zero
+        assert reward.item() == pytest.approx(0.0, abs=0.01)
+
+    def test_reward_range_bounded(self):
+        from torchrl.envs.llm.reward.ifeval._scorer import IFEvalScoreData, IfEvalScorer
+
+        scorer = IfEvalScorer()
+        score = IFEvalScoreData(
+            prompt_level_strict_acc=torch.tensor([True]),
+            inst_level_strict_acc=torch.tensor([True]),
+            prompt_level_loose_acc=torch.tensor([True]),
+            inst_level_loose_acc=torch.tensor([True]),
+            batch_size=(),
+        )
+        reward = scorer.default_reward_aggregator(
+            score,
+            think_blocks=["t"],
+            answer_blocks=["a"],
+        )
+        assert 0.0 <= reward.item() <= 1.2
+
+
 class TestTools:
     @pytest.mark.skipif(not _has_transformers, reason="requires transformers")
     def test_python_interpreter_single_batch(self):
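
For readers who want to poke at the aggregator outside of pytest, here is a minimal sketch mirroring test_perfect_score_with_format above (it assumes a TorchRL build that contains the IFEval scorer as modified by this commit):

import torch
from torchrl.envs.llm.reward.ifeval._scorer import IFEvalScoreData, IfEvalScorer

# All four IFEval metrics pass, and the response carries exactly one think
# block and one answer block, so both structure bonuses apply.
scorer = IfEvalScorer()
score = IFEvalScoreData(
    prompt_level_strict_acc=torch.tensor([True]),
    inst_level_strict_acc=torch.tensor([True]),
    prompt_level_loose_acc=torch.tensor([True]),
    inst_level_loose_acc=torch.tensor([True]),
    batch_size=(),
)
reward = scorer.default_reward_aggregator(
    score, think_blocks=["reasoning"], answer_blocks=["answer"]
)
print(reward.item())  # ~1.15: weighted metrics (1.0) + format bonus (0.1 + 0.05)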

torchrl/envs/llm/reward/ifeval/_scorer.py
Lines changed: 21 additions & 82 deletions

@@ -165,7 +165,7 @@ class IfEvalScorer(Transform):
             it must take as input an :class:`~torchrl.envs.llm.IFEvalScoreData` instance, and optionally `think_blocks`, `answer_blocks` and `complete` keyword arguments
             containing the list of think and answer blocks, respectively.
             It must return a tensor with shape identical to the env batch-size with an additional trailing singleton dimension.
-            Defaults to `True`. The default aggregator is a simple sum over the fields of :class:`~torchrl.envs.llm.IFEvalScoreData`.
+            Defaults to `True`. The default aggregator computes a weighted average of the IFEval metrics plus a small format bonus (reward range ~[0, 1.15]).
         format_weights (list[float], optional): The weights for the format fields (`prompt_level_strict_acc`, `inst_level_strict_acc`,
             `prompt_level_loose_acc`, `inst_level_loose_acc`, in that order). Defaults to `[0.4, 0.3, 0.2, 0.1]`.
             This is only used if `aggregate_reward` is `True` and the default aggregator is used.
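
To make the weighting concrete, a small sketch using the default [0.4, 0.3, 0.2, 0.1] weights from the docstring (the metric values below are invented for illustration):

import torch

# Order: prompt_level_strict, inst_level_strict, prompt_level_loose, inst_level_loose.
weights = torch.tensor([0.4, 0.3, 0.2, 0.1])
# Hypothetical response: the strict prompt-level check fails, half of the
# per-instruction strict checks pass, and both loose checks pass.
metrics = torch.tensor([0.0, 0.5, 1.0, 1.0])
format_score = (metrics * weights).sum()
print(format_score.item())  # 0.45 = 0.0 + 0.15 + 0.2 + 0.1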
@@ -230,65 +230,44 @@ def default_reward_aggregator(
         answer_blocks: list[str] | None = None,
         complete: bool | torch.Tensor | None = None,
     ) -> torch.Tensor:
-        r"""Improved reward aggregation function with tiered multiplicative scoring.
+        """Reward aggregation based on weighted IFEval metrics plus a small format bonus.

         Args:
             score (IFEvalScoreData): The score data.
             think_blocks (list[str], optional): The list of think blocks.
             answer_blocks (list[str], optional): The list of answer blocks.
-            complete (bool, optional): Whether the response is complete (ends with a eos token).
+            complete (bool, optional): Whether the response is complete (ends with an eos token).

-        The reward uses a tiered multiplicative system:
+        The reward is computed as:

-        1. Critical failure check: No answer blocks = 0 reward
-        2. Base format score (0-1): Weighted average of format metrics
-        3. Structure multiplier (0.1-1.0): Penalties for missing/multiple blocks
-        4. Quality bonus (0-0.5): Rewards for high quality and completion
-        5. Task complexity scaling: More requirements = higher potential rewards
+            reward = weighted_avg(strict/loose metrics) + format_bonus

-        The final formula is:
-        reward = (format_score + quality_bonus) * structure_multiplier * complexity_scale
+        where ``format_bonus`` gives a small additive reward (up to 0.15) for
+        well-structured responses with proper ``<think>`` / ``<answer>`` tags.

-        This provides better learning signals by:
-        - Requiring critical elements (answer tags) for meaningful rewards
-        - Using multiplicative scaling to reward doing everything well
-        - Scaling rewards based on task complexity
-        - Providing clear failure modes and success incentives
-
-        Reward range: 0.0 to ~1.5-2.7 depending on task complexity (more instructions = higher max reward).
+        Reward range: approximately [0.0, 1.15].
         """
         default_dtype = torch.get_default_dtype()
         score = score.to(default_dtype)
-        # Critical failure check - no answer = no reward
-        if not answer_blocks:
-            return torch.zeros(
-                score.batch_size + (1,), device=score.device, dtype=default_dtype
-            )
+        zero = torch.zeros(
+            score.batch_size + (1,), device=score.device, dtype=default_dtype
+        )

-        # Base format score calculation (0-1)
         format_components = torch.stack(
             [
                 score.prompt_level_strict_acc.sum(-1, keepdim=True)
                 if score.prompt_level_strict_acc is not None
-                else torch.zeros(
-                    score.batch_size + (1,), device=score.device, dtype=default_dtype
-                ),  # Single value
+                else zero,
                 score.inst_level_strict_acc.mean(-1, keepdim=True)
                 if score.inst_level_strict_acc is not None
-                else torch.zeros(
-                    score.batch_size + (1,), device=score.device, dtype=default_dtype
-                ),  # Average across instructions
+                else zero,
                 score.prompt_level_loose_acc.sum(-1, keepdim=True)
                 if score.prompt_level_loose_acc is not None
-                else torch.zeros(
-                    score.batch_size + (1,), device=score.device, dtype=default_dtype
-                ),  # Single value
+                else zero,
                 score.inst_level_loose_acc.mean(-1, keepdim=True)
                 if score.inst_level_loose_acc is not None
-                else torch.zeros(
-                    score.batch_size + (1,), device=score.device, dtype=default_dtype
-                ),  # Average across instructions
+                else zero,
             ],
             -1,
         )
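
The sum/mean split in the stack above reflects the shape of the metrics: prompt-level accuracies carry a single pass/fail flag per prompt, while instruction-level accuracies carry one flag per instruction and are averaged into a fraction. An illustrative sketch (tensor values invented):

import torch

prompt_level = torch.tensor([1.0])          # one pass/fail flag for the prompt
inst_level = torch.tensor([1.0, 0.0, 1.0])  # three instructions, two satisfied

print(prompt_level.sum(-1, keepdim=True))   # tensor([1.]) - effectively a copy
print(inst_level.mean(-1, keepdim=True))    # tensor([0.6667]) - fraction satisfied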
@@ -299,53 +278,13 @@ def default_reward_aggregator(
         )
         format_score = (format_components * weights).sum(dim=-1, keepdim=True)

-        # Structure multiplier (0.1-1.0)
-        structure_multiplier = 1.0
-
-        # Heavy penalty for missing think blocks (but not zero)
-        if not think_blocks:
-            structure_multiplier *= 0.3
-        elif len(think_blocks) > 1:
-            structure_multiplier *= 0.7  # Penalty for multiple think blocks
-
-        # Penalty for multiple answer blocks
-        if len(answer_blocks) > 1:
-            structure_multiplier *= 0.7
-
-        # Quality bonus (0-0.5)
-        quality_bonus = torch.zeros_like(format_score)
-
-        # Bonus for high quality responses
-        if format_score > 0.8:
-            quality_bonus += 0.3
-
-        # Completion bonus
-        if complete is not None:
-            if isinstance(complete, torch.Tensor):
-                completion_bonus = complete.to(default_dtype) * 0.2
-            else:
-                completion_bonus = float(complete) * 0.2
-            quality_bonus += completion_bonus
-
-        # Task complexity scaling based on number of instructions
-        # More instructions = higher potential rewards
-        if (
-            score.inst_level_strict_acc is not None
-            and score.inst_level_strict_acc.numel() > 0
-        ):
-            num_instructions = score.inst_level_strict_acc.shape[-1]
-        else:
-            num_instructions = 1
-        complexity_scale = (
-            1.0 + (num_instructions - 1) * 0.2
-        )  # 1.0 for 1 instruction, 1.2 for 2, etc.
-
-        # Final reward: (format + quality) * structure_multiplier * complexity_scale
-        final_reward = (
-            (format_score + quality_bonus) * structure_multiplier * complexity_scale
-        )
-        final_reward = final_reward.to(default_dtype)
+        format_bonus = 0.0
+        if answer_blocks and len(answer_blocks) == 1:
+            format_bonus += 0.1
+        if think_blocks and len(think_blocks) == 1:
+            format_bonus += 0.05

+        final_reward = (format_score + format_bonus).to(default_dtype)
         return final_reward

     def _step(
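
As a back-of-the-envelope check of the advertised range, assuming the format weights keep their [0.4, 0.3, 0.2, 0.1] defaults:

# Best case: all metrics pass (weighted average = 1.0) plus both bonuses
# (+0.1 for exactly one answer block, +0.05 for exactly one think block).
max_reward = (0.4 + 0.3 + 0.2 + 0.1) * 1.0 + 0.1 + 0.05
assert abs(max_reward - 1.15) < 1e-9  # matches the docstring's ~[0.0, 1.15]

Note two behavioral changes visible in the diff: the old aggregator returned a hard zero whenever answer_blocks was empty, while the new one simply withholds the bonus and returns the weighted metric average (test_zero_score_no_answer covers this with all-False metrics); and the complete argument is still accepted but no longer affects the result, since the completion bonus was removed.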
