@@ -94,17 +94,15 @@ async def _verify_sampling_model_versions(self, exps_list, expected_model_versio
9494 if current_task :
9595 await current_task
9696
97- async def _flexible_verify_model_version (self , step , staleness_limit ):
97+ async def _flexible_verify_model_version (self , step , max_staleness ):
9898 _ , metrics , _ = await self .sample_strategy .sample (step = step )
9999 self .assertGreaterEqual (
100100 metrics ["sample/model_version/min" ],
101- step - staleness_limit ,
101+ step - max_staleness ,
102102 f"Min model version mismatch at step { step } " ,
103103 )
104104
105- async def _flexible_verify_sampling_model_versions (
106- self , exps_list , check_steps , staleness_limit
107- ):
105+ async def _flexible_verify_sampling_model_versions (self , exps_list , check_steps , max_staleness ):
108106 self ._init_buffer_writer_and_sample_strategy ()
109107
110108 # Write experiences to buffer, while sample and validate model versions
@@ -115,7 +113,7 @@ async def _flexible_verify_sampling_model_versions(
115113 if current_task :
116114 await current_task
117115 current_task = asyncio .create_task (
118- self ._flexible_verify_model_version (step , staleness_limit )
116+ self ._flexible_verify_model_version (step , max_staleness )
119117 )
120118 await asyncio .sleep (0.1 )
121119
@@ -146,9 +144,9 @@ async def test_default_queue_default_sample_strategy(self):
146144 await self ._verify_sampling_model_versions (exps_list , expected_model_versions_map )
147145
148146 async def test_default_queue_staleness_control_sample_strategy (self ):
149- staleness_limit = 3
147+ max_staleness = 3
150148 self .config .algorithm .sample_strategy = "staleness_control"
151- self .config .algorithm .sample_strategy_args = {"staleness_limit " : staleness_limit }
149+ self .config .algorithm .sample_strategy_args = {"max_staleness " : max_staleness }
152150 self .config .buffer .trainer_input .experience_buffer = ExperienceBufferConfig (
153151 name = "default_queue_staleness_control" ,
154152 storage_type = StorageType .QUEUE .value ,
@@ -161,15 +159,15 @@ async def test_default_queue_staleness_control_sample_strategy(self):
161159 steps = self ._default_steps ()
162160 expected_model_versions_map = {}
163161 for step in steps :
164- predict_version = max (step - staleness_limit , 0 )
162+ predict_version = max (step - max_staleness , 0 )
165163 expected_model_versions_map [step ] = [
166164 predict_version + i // self .exp_write_batch_size
167165 for i in range (self .config .buffer .train_batch_size )
168166 ]
169167
170168 await self ._verify_sampling_model_versions (exps_list , expected_model_versions_map )
171169
172- def _simulate_priority_queue (self , steps , staleness_limit = float ("inf" )):
170+ def _simulate_priority_queue (self , steps , max_staleness = float ("inf" )):
173171 expected_model_versions_map = {}
174172 buffer = deque ()
175173 exp_pool = deque ()
@@ -187,7 +185,7 @@ def _simulate_priority_queue(self, steps, staleness_limit=float("inf")):
187185 exp_pool .extend (buffer .pop ())
188186 while len (exp_pool ) > 0 and len (batch_versions ) < train_batch_size :
189187 exp_version = exp_pool .popleft ()
190- if exp_version < step - staleness_limit :
188+ if exp_version < step - max_staleness :
191189 continue
192190 batch_versions .append (exp_version )
193191 if len (batch_versions ) >= train_batch_size :
@@ -214,9 +212,9 @@ async def test_priority_queue_default_sample_strategy(self):
214212 await self ._verify_sampling_model_versions (exps_list , expected_model_versions_map )
215213
216214 async def test_priority_queue_staleness_control_sample_strategy (self ):
217- staleness_limit = 2
215+ max_staleness = 2
218216 self .config .algorithm .sample_strategy = "staleness_control"
219- self .config .algorithm .sample_strategy_args = {"staleness_limit " : staleness_limit }
217+ self .config .algorithm .sample_strategy_args = {"max_staleness " : max_staleness }
220218 self .config .buffer .trainer_input .experience_buffer = ExperienceBufferConfig (
221219 name = "priority_queue_staleness_control" ,
222220 storage_type = StorageType .QUEUE .value ,
@@ -227,14 +225,14 @@ async def test_priority_queue_staleness_control_sample_strategy(self):
227225 # init testing data
228226 exps_list = self ._default_exp_list ()
229227 steps = self ._default_steps ()
230- expected_model_versions_map = self ._simulate_priority_queue (steps , staleness_limit )
228+ expected_model_versions_map = self ._simulate_priority_queue (steps , max_staleness )
231229
232230 await self ._verify_sampling_model_versions (exps_list , expected_model_versions_map )
233231
234232 async def test_sql_staleness_control_sample_strategy (self ):
235- staleness_limit = 2
233+ max_staleness = 2
236234 self .config .algorithm .sample_strategy = "staleness_control"
237- self .config .algorithm .sample_strategy_args = {"staleness_limit " : staleness_limit }
235+ self .config .algorithm .sample_strategy_args = {"max_staleness " : max_staleness }
238236 self .config .buffer .trainer_input .experience_buffer = ExperienceBufferConfig (
239237 name = "sql_staleness_control" ,
240238 storage_type = StorageType .SQL .value ,
@@ -245,7 +243,7 @@ async def test_sql_staleness_control_sample_strategy(self):
245243 exps_list = self ._default_exp_list ()
246244 steps = self ._default_steps ()
247245
248- await self ._flexible_verify_sampling_model_versions (exps_list , steps , staleness_limit )
246+ await self ._flexible_verify_sampling_model_versions (exps_list , steps , max_staleness )
249247
250248 def tearDown (self ):
251249 asyncio .run (self .buffer_writer .release ())
0 commit comments