
Commit c627783

Add uuid to make sure eviction clears model
1 parent 075ab41 commit c627783

File tree

1 file changed

+36
-6
lines changed


sdks/python/apache_beam/ml/inference/model_manager.py

Lines changed: 36 additions & 6 deletions
@@ -24,6 +24,7 @@
 usage and performance.
 """
 
+import uuid
 import time
 import threading
 import subprocess
@@ -249,6 +250,34 @@ def _solve(self):
       logger.error("Solver failed: %s", e)
 
 
+class TrackedModelProxy:
+  def __init__(self, obj):
+    object.__setattr__(self, "_wrapped_obj", obj)
+    object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4()))
+
+  def __getattr__(self, name):
+    return getattr(self._wrapped_obj, name)
+
+  def __setattr__(self, name, value):
+    setattr(self._wrapped_obj, name, value)
+
+  def __call__(self, *args, **kwargs):
+    return self._wrapped_obj(*args, **kwargs)
+
+  def __str__(self):
+    return str(self._wrapped_obj)
+
+  def __repr__(self):
+    return repr(self._wrapped_obj)
+
+  def __dir__(self):
+    return dir(self._wrapped_obj)
+
+  def unsafe_hard_delete(self):
+    if hasattr(self._wrapped_obj, "unsafe_hard_delete"):
+      self._wrapped_obj.unsafe_hard_delete()
+
+
 class ModelManager:
   _lock = threading.Lock()
 
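For context, a minimal standalone sketch (not part of the commit) of how the proxy behaves: it reuses a trimmed copy of TrackedModelProxy from the hunk above and wraps a hypothetical FakeModel, showing that attribute access and calls are forwarded to the wrapped object while the proxy carries its own _beam_tracking_id.

# Minimal sketch, assuming only the trimmed proxy below; FakeModel is a
# hypothetical stand-in for a loaded model and not part of the commit.
import uuid


class TrackedModelProxy:
  def __init__(self, obj):
    # Store on the proxy itself, bypassing __setattr__ delegation.
    object.__setattr__(self, "_wrapped_obj", obj)
    object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4()))

  def __getattr__(self, name):
    # Only reached when the attribute is not on the proxy itself.
    return getattr(self._wrapped_obj, name)

  def __call__(self, *args, **kwargs):
    return self._wrapped_obj(*args, **kwargs)


class FakeModel:
  def predict(self, x):
    return x * 2


model = TrackedModelProxy(FakeModel())
print(model.predict(21))         # 42 -- attribute access is delegated
print(model._beam_tracking_id)   # unique per wrapper, usable as an eviction key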
@@ -546,12 +575,13 @@ def _perform_eviction(self, key, tag, instance, score):
     if key in self._idle_lru:
       del self._idle_lru[key]
 
-    if instance in self._models[tag]:
-      self._models[tag].remove(instance)
-
-    if hasattr(instance, "unsafe_hard_delete"):
-      instance.unsafe_hard_delete()
+    target_id = instance._beam_tracking_id
+    for i, inst in enumerate(self._models[tag]):
+      if inst._beam_tracking_id == target_id:
+        del self._models[tag][i]
+        break
 
+    instance.unsafe_hard_delete()
     del instance
     gc.collect()
     torch.cuda.empty_cache()
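A minimal sketch (not part of the commit) of why matching on a tracking id is safer than the previous list.remove(instance): remove() and the `in` check rely on __eq__, which an arbitrary model object may override, so they can drop the wrong cache entry or leave the evicted model in place. AlwaysEqual below is a hypothetical stand-in for such an object; the rationale is inferred from the commit message, not stated in it.

# Hedged illustration only; AlwaysEqual is hypothetical.
import uuid


class AlwaysEqual:
  """Stand-in for a model whose __eq__ does not mean object identity."""
  def __eq__(self, other):
    return True


models = [AlwaysEqual(), AlwaysEqual(), AlwaysEqual()]
victim = models[2]

buggy = list(models)
buggy.remove(victim)                       # removes the first element, not the victim
print(any(m is victim for m in buggy))     # True -- the "evicted" model is still cached

# uuid-based matching, as in _perform_eviction above, is unambiguous.
ids = [str(uuid.uuid4()) for _ in models]
target = ids[2]
for i, tracking_id in enumerate(ids):
  if tracking_id == target:
    del models[i]
    break
print(any(m is victim for m in models))    # False -- exactly the victim is gone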
@@ -565,7 +595,7 @@ def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost):
     with self._load_lock:
       logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
       isolation_baseline_snap, _, _ = self._monitor.get_stats()
-      instance = loader_func()
+      instance = TrackedModelProxy(loader_func())
       _, peak_during_load, _ = self._monitor.get_stats()
 
       with self._cv:
