Skip to content

Commit bf3dbf0

Browse files
committed
simplify sampling, add todo
1 parent 22891bf commit bf3dbf0

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

src/forge/data/rewards.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import random
78
import re
89

910

@@ -111,7 +112,7 @@ class LanguageReward:
111112
match_reward: Reward when detected language matches target (default: 1.0)
112113
no_match_reward: Reward when language doesn't match (default: 0.0)
113114
tag: Tag name to use (default "思考" for multilingual, can use "think", etc.)
114-
debug: If True, print debug samples showing model outputs and detected language
115+
debug: If True, print debug samples showing model outputs and detected language for every sample
115116
debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls)
116117
117118
Note: Requires langid to be installed. Install with: pip install langid
@@ -164,13 +165,9 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo
164165
Returns:
165166
match_reward if detected language matches target, no_match_reward otherwise
166167
"""
167-
# Increment counter for sampling
168-
self._debug_counter += 1
169-
should_debug = (
170-
self.debug
171-
and self.debug_sample_rate > 0
172-
and (self._debug_counter % int(1 / self.debug_sample_rate)) == 0
173-
)
168+
169+
# TODO: refactor pending https://github.com/meta-pytorch/torchforge/issues/187
170+
should_debug = debug or (random.random() < self.debug_sample_rate)
174171

175172
if not response:
176173
if should_debug:

0 commit comments

Comments
 (0)