Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions trinity/buffer/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"reward_shaping_mapper": "trinity.buffer.operators.mappers.reward_shaping_mapper.RewardShapingMapper",
"pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator",
"data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator",
"low_quality_experience_filter": "trinity.buffer.operators.filters.low_quality_filter.LowQualityExperienceFilter",

},
)

Expand Down
12 changes: 12 additions & 0 deletions trinity/buffer/operators/filters/low_quality_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import List, Tuple
import math

from trinity.buffer.operators.experience_operator import ExperienceOperator
from trinity.common.experience import Experience


class LowQualityExperienceFilter(ExperienceOperator):
def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
kept = [e for e in exps if e.reward is not None and e.reward == e.reward]

return kept, {"filtered_count": len(exps) - len(kept)}