
Commit 540cc8c

Add timeout and force-poll GPU usage to prevent race condition
1 parent e375b45 commit 540cc8c

File tree

2 files changed: 23 additions & 1 deletion


sdks/python/apache_beam/ml/inference/model_manager.py

Lines changed: 19 additions & 1 deletion
@@ -103,6 +103,21 @@ def get_stats(self) -> Tuple[float, float, float]:
     with self._lock:
       return self._current_usage, self._peak_usage, self._total_memory
 
+  def refresh(self):
+    """Forces an immediate poll of the GPU."""
+    usage = self._get_nvidia_smi_used()
+    now = time.time()
+    with self._lock:
+      self._current_usage = usage
+      self._memory_history.append((now, usage))
+      # Recalculate peak immediately
+      while self._memory_history and (now - self._memory_history[0][0]
+                                      > self._peak_window_seconds):
+        self._memory_history.popleft()
+      self._peak_usage = (
+          max(m for _, m in self._memory_history)
+          if self._memory_history else usage)
+
   def _get_nvidia_smi_used(self) -> float:
     try:
       cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
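The peak-recalculation loop in refresh() is the heart of the forced poll: each sample is appended to a time-stamped history, entries older than the peak window are dropped, and the peak is recomputed from what remains. A standalone sketch of that sliding-window bookkeeping (simplified for illustration; not the Beam code itself):

    import time
    from collections import deque

    def record_sample(history: deque, usage_mb: float, window_s: float) -> float:
        """Append a sample, drop entries older than window_s, return the peak."""
        now = time.time()
        history.append((now, usage_mb))
        while history and now - history[0][0] > window_s:
            history.popleft()
        return max(m for _, m in history)

    history: deque = deque()
    print(record_sample(history, 1024.0, 60.0))  # -> 1024.0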
@@ -360,7 +375,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
         if self._evict_to_make_space(limit, est_cost, requesting_tag=tag):
           continue
 
-        self._cv.wait()
+        self._cv.wait(timeout=10.0)
 
     finally:
       if self._wait_queue and self._wait_queue[0][2] is my_id:
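Replacing the unbounded wait with self._cv.wait(timeout=10.0) bounds how long a waiter can sleep if a notification is missed, for instance when memory is freed between the capacity check and the wait. A general sketch of the pattern (not Beam's code): the predicate is rechecked on every wakeup, so a timeout wakeup is harmless and a missed notify can no longer block a thread forever.

    import threading

    cv = threading.Condition()
    have_capacity = False

    def wait_for_capacity():
        with cv:
            while not have_capacity:
                # Wakes on notify_all() or after 10s, whichever comes
                # first; the loop re-evaluates the predicate either way.
                cv.wait(timeout=10.0)

    def grant_capacity():
        global have_capacity
        with cv:
            have_capacity = True
            cv.notify_all()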
@@ -492,6 +507,7 @@ def _perform_eviction(self, key, tag, instance, score):
     del instance
     gc.collect()
     torch.cuda.empty_cache()
+    self._monitor.refresh()
     self._monitor.reset_peak()
 
   def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost):
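Note that the new self._monitor.refresh() call runs before reset_peak(), so the peak is recomputed from a post-eviction sample rather than from stale background-poller data. An illustrative teardown sequence mirroring this hunk (assumes a CUDA build of PyTorch; monitor stands for any object exposing refresh() and reset_peak()):

    import gc
    import torch

    def release_and_resync(instance, monitor):
        del instance               # drop the last reference to the model
        gc.collect()               # collect any cycles holding GPU tensors
        torch.cuda.empty_cache()   # return cached blocks to the driver
        monitor.refresh()          # poll the GPU now, not on the next tick
        monitor.reset_peak()       # recompute the peak from fresh samples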
@@ -538,6 +554,8 @@ def _delete_all_models(self):
     self._active_counts.clear()
     gc.collect()
     torch.cuda.empty_cache()
+    self._monitor.refresh()
+    self._monitor.reset_peak()
 
   def _force_reset(self):
     logger.warning("Force Reset Triggered")

sdks/python/apache_beam/ml/inference/model_manager_test.py

Lines changed: 4 additions & 0 deletions
@@ -81,6 +81,10 @@ def free(self, amount_mb):
       self.history.pop(0)
     self._peak = max(self.history)
 
+  def refresh(self):
+    """Simulates a refresh of the monitor stats (no-op for mock)."""
+    pass
+
 
 class MockModel:
   def __init__(self, name, size, monitor):
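The mock gains refresh() because the manager now calls it during eviction and teardown; the monitor is duck-typed, so any test double must expose the same surface or those paths raise AttributeError. A minimal sketch of that implied interface (the Protocol below is illustrative, not part of the Beam codebase):

    from typing import Protocol, Tuple

    class MonitorLike(Protocol):
        def get_stats(self) -> Tuple[float, float, float]: ...
        def refresh(self) -> None: ...
        def reset_peak(self) -> None: ...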
