Commit 434715d

Add low-quality experience filter operator #470 (#473)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 15a6d2a commit 434715d

2 files changed: 17 additions, 0 deletions

trinity/buffer/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
         "reward_shaping_mapper": "trinity.buffer.operators.mappers.reward_shaping_mapper.RewardShapingMapper",
         "pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator",
         "data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator",
+        "invalid_reward_filter": "trinity.buffer.operators.filters.reward_filter.InvalidRewardFilter",
     },
 )
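The hunk above only shows the mapping from an operator name to a dotted class path; how the project resolves those strings is not part of this commit. As a rough illustration of the general pattern, and not of Trinity's actual registry code, a name-to-path table can be resolved lazily with `importlib`. The table `LAZY_OPERATORS` and the helper `resolve_operator` are hypothetical names, and standard-library classes stand in for the real operators so the sketch runs on its own:

```python
import importlib

# Hypothetical table in the same "name -> dotted class path" style as the diff.
# Standard-library classes stand in for the real operator classes.
LAZY_OPERATORS = {
    "ordered_dict": "collections.OrderedDict",
    "counter": "collections.Counter",
}


def resolve_operator(name: str):
    """Import the module on first use and return the class registered under `name`."""
    # Split "package.module.ClassName" into the module path and the class name.
    module_path, _, class_name = LAZY_OPERATORS[name].rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


counter_cls = resolve_operator("counter")
print(counter_cls.__name__)
```

Under this assumption, registering a new operator costs one line in the table, and the module that defines it is only imported when the operator is actually requested.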

trinity/buffer/operators/filters/reward_filter.py

Lines changed: 16 additions & 0 deletions
@@ -50,3 +50,19 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
         final_count = len(result_exps)
         metrics["filtered_count"] = original_count - final_count
         return result_exps, metrics
+
+
+class InvalidRewardFilter(ExperienceOperator):
+    """
+    Filters out experiences with invalid reward values.
+
+    Note: This operator assumes that rewards are already computed and stored in the
+    Experience object. Any experience with a missing (`None`) or invalid (`NaN`)
+    reward is removed to prevent low-quality data from entering the training
+    pipeline.
+    """
+
+    def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
+        kept = [e for e in exps if e.reward is not None and e.reward == e.reward]
+
+        return kept, {"filtered_count": len(exps) - len(kept)}
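The `e.reward == e.reward` test in the added `process` method relies on NaN being the only float that compares unequal to itself. A minimal standalone sketch of the same filtering rule is given below. The `Experience` dataclass is a hypothetical stand-in with a single `reward` field (the real class is not shown in this diff), `filter_invalid_rewards` is an illustrative name, and `math.isnan` is swapped in for the self-comparison to make the intent explicit:

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Experience:
    """Hypothetical stand-in; the real Experience class is not shown in the diff."""
    reward: Optional[float]


def filter_invalid_rewards(exps: List[Experience]) -> Tuple[List[Experience], dict]:
    """Keep experiences whose reward is neither None nor NaN."""
    # math.isnan(x) is equivalent to `x != x` for floats.
    kept = [e for e in exps if e.reward is not None and not math.isnan(e.reward)]
    return kept, {"filtered_count": len(exps) - len(kept)}


batch = [Experience(1.0), Experience(None), Experience(float("nan")), Experience(0.0)]
kept, metrics = filter_invalid_rewards(batch)
print([e.reward for e in kept], metrics)  # [1.0, 0.0] {'filtered_count': 2}
```

Two details follow from the rule as written: a reward of `0.0` is kept, because the check is `is not None` rather than truthiness, and infinite rewards also pass, since `inf == inf` is true. Whether infinities should count as invalid is a separate decision that this commit does not address.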
