11import asyncio
2- from collections import deque
32import shutil
3+ from collections import deque
4+
45import torch
6+
57from tests .tools import RayUnittestBaseAysnc , get_template_config
6- from trinity .algorithm .sample_strategy .sample_strategy import SAMPLE_STRATEGY , SampleStrategy
8+ from trinity .algorithm .sample_strategy .sample_strategy import (
9+ SAMPLE_STRATEGY ,
10+ SampleStrategy ,
11+ )
712from trinity .buffer .buffer import get_buffer_writer
8- from trinity .common .config import ExperienceBufferConfig
9- from trinity .common .constants import StorageType
1013from trinity .common .experience import Experience
1114
1215
@@ -29,16 +32,30 @@ def _default_exp_list(self):
2932 ]
3033 for i in range (self .num_steps )
3134 ]
32-
35+
3336 def _default_steps (self ):
3437 return [0 , 5 , 10 , 15 ]
3538
3639 async def _verify_model_version (self , step , expected_versions ):
3740 batch , metrics , _ = await self .sample_strategy .sample (step = step )
38- self .assertEqual (batch .rewards .tolist (), expected_versions , f"Model versions mismatch at step { step } " )
39- self .assertEqual (metrics ['sample/model_version/min' ], min (expected_versions ), f"Min model version mismatch at step { step } " )
40- self .assertEqual (metrics ['sample/model_version/max' ], max (expected_versions ), f"Max model version mismatch at step { step } " )
41- self .assertEqual (metrics ['sample/model_version/mean' ], sum (expected_versions ) / len (expected_versions ), f"Mean model version mismatch at step { step } " )
41+ self .assertEqual (
42+ batch .rewards .tolist (), expected_versions , f"Model versions mismatch at step { step } "
43+ )
44+ self .assertEqual (
45+ metrics ["sample/model_version/min" ],
46+ min (expected_versions ),
47+ f"Min model version mismatch at step { step } " ,
48+ )
49+ self .assertEqual (
50+ metrics ["sample/model_version/max" ],
51+ max (expected_versions ),
52+ f"Max model version mismatch at step { step } " ,
53+ )
54+ self .assertEqual (
55+ metrics ["sample/model_version/mean" ],
56+ sum (expected_versions ) / len (expected_versions ),
57+ f"Mean model version mismatch at step { step } " ,
58+ )
4259
4360 async def _verify_sampling_model_versions (self , exps_list , expected_model_versions_map ):
4461 # Initialize buffer writer and sample strategy
@@ -51,7 +68,7 @@ async def _verify_sampling_model_versions(self, exps_list, expected_model_versio
5168 buffer_config = self .config .buffer ,
5269 ** self .config .algorithm .sample_strategy_args ,
5370 )
54-
71+
5572 # Write experiences to buffer, while sample and validate model versions
5673 current_task = None
5774 for step , exps in enumerate (exps_list ):
@@ -106,7 +123,7 @@ async def test_default_queue_staleness_control_sample_strategy(self):
106123
107124 await self ._verify_sampling_model_versions (exps_list , expected_model_versions_map )
108125
109- def _simulate_priority_queue (self , steps , staleness_limit = float (' inf' )):
126+ def _simulate_priority_queue (self , steps , staleness_limit = float (" inf" )):
110127 expected_model_versions_map = {}
111128 buffer = deque ()
112129 exp_pool = deque ()
0 commit comments