Skip to content

Commit 9a9b79f

Browse files
authored
Merge branch 'meta-pytorch:main' into openenv
2 parents daa57c5 + 448c18a commit 9a9b79f

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

src/forge/actors/replay_buffer.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def age_evict(
3333
"""Buffer eviction policy, remove old or over-sampled entries"""
3434
indices = []
3535
for i, entry in enumerate(buffer):
36-
if max_age and policy_version - entry.data.policy_version > max_age:
36+
if max_age is not None and policy_version - entry.data.policy_version > max_age:
3737
continue
38-
if max_samples and entry.sample_count >= max_samples:
38+
if max_samples is not None and entry.sample_count >= max_samples:
3939
continue
4040
indices.append(i)
4141
return indices
@@ -120,6 +120,27 @@ async def sample(
120120
entry.sample_count += 1
121121
sampled_episodes.append(entry.data)
122122

123+
# Calculate and record policy age metrics for sampled episodes
124+
sampled_policy_ages = [
125+
curr_policy_version - ep.policy_version for ep in sampled_episodes
126+
]
127+
if sampled_policy_ages:
128+
record_metric(
129+
"buffer/sample/avg_sampled_policy_age",
130+
sum(sampled_policy_ages) / len(sampled_policy_ages),
131+
Reduce.MEAN,
132+
)
133+
record_metric(
134+
"buffer/sample/max_sampled_policy_age",
135+
max(sampled_policy_ages),
136+
Reduce.MAX,
137+
)
138+
record_metric(
139+
"buffer/sample/min_sampled_policy_age",
140+
min(sampled_policy_ages),
141+
Reduce.MIN,
142+
)
143+
123144
# Reshape into (dp_size, bsz, ...)
124145
reshaped_episodes = [
125146
sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size]
@@ -149,22 +170,6 @@ def _evict(self, curr_policy_version):
149170
)
150171
self.buffer = deque(self._collect(indices))
151172

152-
# Record evict metrics
153-
policy_age = [
154-
curr_policy_version - ep.data.policy_version for ep in self.buffer
155-
]
156-
if policy_age:
157-
record_metric(
158-
"buffer/evict/avg_policy_age",
159-
sum(policy_age) / len(policy_age),
160-
Reduce.MEAN,
161-
)
162-
record_metric(
163-
"buffer/evict/max_policy_age",
164-
max(policy_age),
165-
Reduce.MAX,
166-
)
167-
168173
evicted_count = buffer_len_before_evict - len(self.buffer)
169174
record_metric("buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM)
170175

0 commit comments

Comments
 (0)