diff --git a/trinity/buffer/operators/__init__.py b/trinity/buffer/operators/__init__.py index 258d4d76b5..e83b7d05ee 100644 --- a/trinity/buffer/operators/__init__.py +++ b/trinity/buffer/operators/__init__.py @@ -9,6 +9,7 @@ "reward_shaping_mapper": "trinity.buffer.operators.mappers.reward_shaping_mapper.RewardShapingMapper", "pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator", "data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator", + "invalid_reward_filter": "trinity.buffer.operators.filters.reward_filter.InvalidRewardFilter", }, ) diff --git a/trinity/buffer/operators/filters/reward_filter.py b/trinity/buffer/operators/filters/reward_filter.py index dc5bd92e7e..07126db12f 100644 --- a/trinity/buffer/operators/filters/reward_filter.py +++ b/trinity/buffer/operators/filters/reward_filter.py @@ -50,3 +50,19 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]: final_count = len(result_exps) metrics["filtered_count"] = original_count - final_count return result_exps, metrics + + +class InvalidRewardFilter(ExperienceOperator): + """ + Filters out experiences with invalid reward values. + + Note: This operator assumes that rewards are already computed and stored in the + Experience object.Any experience with a missing (`None`) or invalid (`NaN`) + reward is removed to prevent low-quality data from entering the training + pipeline. + """ + + def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]: + kept = [e for e in exps if e.reward is not None and e.reward == e.reward] + + return kept, {"filtered_count": len(exps) - len(kept)}