Skip to content

Commit 08c5144

Browse files
committed
pre commit fix
1 parent ba74224 commit 08c5144

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

tests/buffer/sample_strategy_test.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import asyncio
2-
from collections import deque
32
import shutil
3+
from collections import deque
4+
45
import torch
6+
57
from tests.tools import RayUnittestBaseAysnc, get_template_config
6-
from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
8+
from trinity.algorithm.sample_strategy.sample_strategy import (
9+
SAMPLE_STRATEGY,
10+
SampleStrategy,
11+
)
712
from trinity.buffer.buffer import get_buffer_writer
8-
from trinity.common.config import ExperienceBufferConfig
9-
from trinity.common.constants import StorageType
1013
from trinity.common.experience import Experience
1114

1215

@@ -29,16 +32,30 @@ def _default_exp_list(self):
2932
]
3033
for i in range(self.num_steps)
3134
]
32-
35+
3336
def _default_steps(self):
3437
return [0, 5, 10, 15]
3538

3639
async def _verify_model_version(self, step, expected_versions):
3740
batch, metrics, _ = await self.sample_strategy.sample(step=step)
38-
self.assertEqual(batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}")
39-
self.assertEqual(metrics['sample/model_version/min'], min(expected_versions), f"Min model version mismatch at step {step}")
40-
self.assertEqual(metrics['sample/model_version/max'], max(expected_versions), f"Max model version mismatch at step {step}")
41-
self.assertEqual(metrics['sample/model_version/mean'], sum(expected_versions) / len(expected_versions), f"Mean model version mismatch at step {step}")
41+
self.assertEqual(
42+
batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}"
43+
)
44+
self.assertEqual(
45+
metrics["sample/model_version/min"],
46+
min(expected_versions),
47+
f"Min model version mismatch at step {step}",
48+
)
49+
self.assertEqual(
50+
metrics["sample/model_version/max"],
51+
max(expected_versions),
52+
f"Max model version mismatch at step {step}",
53+
)
54+
self.assertEqual(
55+
metrics["sample/model_version/mean"],
56+
sum(expected_versions) / len(expected_versions),
57+
f"Mean model version mismatch at step {step}",
58+
)
4259

4360
async def _verify_sampling_model_versions(self, exps_list, expected_model_versions_map):
4461
# Initialize buffer writer and sample strategy
@@ -51,7 +68,7 @@ async def _verify_sampling_model_versions(self, exps_list, expected_model_versio
5168
buffer_config=self.config.buffer,
5269
**self.config.algorithm.sample_strategy_args,
5370
)
54-
71+
5572
# Write experiences to buffer, while sample and validate model versions
5673
current_task = None
5774
for step, exps in enumerate(exps_list):
@@ -106,7 +123,7 @@ async def test_default_queue_staleness_control_sample_strategy(self):
106123

107124
await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)
108125

109-
def _simulate_priority_queue(self, steps, staleness_limit = float('inf')):
126+
def _simulate_priority_queue(self, steps, staleness_limit=float("inf")):
110127
expected_model_versions_map = {}
111128
buffer = deque()
112129
exp_pool = deque()

0 commit comments

Comments
 (0)