Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions trinity/buffer/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"reward_shaping_mapper": "trinity.buffer.operators.mappers.reward_shaping_mapper.RewardShapingMapper",
"pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator",
"data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator",
"low_quality_experience_filter": "trinity.buffer.operators.filters.low_quality_filter.LowQualityExperienceFilter",

},
)

Expand Down
19 changes: 19 additions & 0 deletions trinity/buffer/operators/filters/low_quality_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import List, Tuple
import math

from trinity.buffer.operators.experience_operator import ExperienceOperator
from trinity.common.experience import Experience


class LowQualityExperienceFilter(ExperienceOperator):
    """Drop experiences whose reward is missing or NaN.

    An experience is removed when its ``reward`` attribute is ``None`` or a
    scalar NaN. All other experiences are kept in their original order.
    """

    @staticmethod
    def _is_nan(reward) -> bool:
        """Return True if *reward* is a scalar that converts to a float NaN.

        ``math.isnan`` accepts any object that can be converted to a float,
        so this also catches NaN stored in non-``float`` scalars (for example
        ``numpy.float32``, which is not a ``float`` subclass). Values that
        cannot be converted (strings, multi-element arrays, oversized ints,
        arbitrary objects) are reported as "not NaN" so they are kept, which
        matches how a plain ``isinstance(reward, float)`` check treats them.
        """
        try:
            return math.isnan(reward)
        except (TypeError, ValueError, OverflowError, RuntimeError):
            # Not a float-convertible scalar; do not filter on that basis.
            return False

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
        """Filter out experiences with a ``None`` or NaN reward.

        Args:
            exps: Experiences to inspect; the input list is not modified.

        Returns:
            A tuple ``(kept, metrics)`` where ``kept`` holds the surviving
            experiences in input order and ``metrics["filtered_count"]`` is
            the number of experiences that were dropped.
        """
        kept = [
            exp
            for exp in exps
            if exp.reward is not None and not self._is_nan(exp.reward)
        ]
        return kept, {"filtered_count": len(exps) - len(kept)}