Skip to content

Commit 99541ff

Browse files
authored
Fix flakiness in model manager and multiprocessshared tests (#37529)
* Fix flakes * Don't use the same tag twice
1 parent dd92b8f commit 99541ff

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

sdks/python/apache_beam/ml/inference/model_manager_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,22 +333,21 @@ def test_single_model_convergence_with_fluctuations(self):
333333
"""
334334
model_name = "fluctuating_model"
335335
model_cost = 3000.0
336-
load_cost = 2500.0
337336
# Fix random seed for reproducibility
338337
random.seed(42)
339338

340339
def loader():
341-
self.mock_monitor.allocate(load_cost)
340+
self.mock_monitor.allocate(model_cost)
342341
return model_name
343342

344343
model = self.manager.acquire_model(model_name, loader)
345344
self.manager.release_model(model_name, model)
346345
initial_est = self.manager._estimator.get_estimate(model_name)
347-
self.assertEqual(initial_est, load_cost)
346+
self.assertEqual(initial_est, model_cost)
348347

349348
def run_inference():
350349
model = self.manager.acquire_model(model_name, loader)
351-
noise = model_cost - load_cost + random.uniform(-300.0, 300.0)
350+
noise = random.uniform(-300.0, 300.0)
352351
self.mock_monitor.allocate(noise)
353352
time.sleep(0.1)
354353
self.mock_monitor.free(noise)

sdks/python/apache_beam/utils/multi_process_shared_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def setUp(self):
289289
for tag in ['basic',
290290
'main',
291291
'to_delete',
292+
'to_keep',
292293
'mix1',
293294
'mix2',
294295
'test_process_exit',
@@ -310,7 +311,7 @@ def tearDown(self):
310311

311312
def test_call(self):
312313
shared = multi_process_shared.MultiProcessShared(
313-
Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
314+
Counter, tag='main', always_proxy=True, spawn_process=True).acquire()
314315
self.assertEqual(shared.get(), 0)
315316
self.assertEqual(shared.increment(), 1)
316317
self.assertEqual(shared.increment(10), 11)
@@ -323,7 +324,8 @@ def test_unsafe_hard_delete_autoproxywrapper(self):
323324
shared2 = multi_process_shared.MultiProcessShared(
324325
Counter, tag='to_delete', always_proxy=True, spawn_process=True)
325326
counter3 = multi_process_shared.MultiProcessShared(
326-
Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
327+
Counter, tag='to_keep', always_proxy=True,
328+
spawn_process=True).acquire()
327329

328330
counter1 = shared1.acquire()
329331
counter2 = shared2.acquire()

0 commit comments

Comments
 (0)