 usage and performance.
 """
 
+import uuid
 import time
 import threading
 import subprocess
@@ -249,6 +250,34 @@ def _solve(self):
             logger.error("Solver failed: %s", e)
 
 
+class TrackedModelProxy:
+    def __init__(self, obj):
+        object.__setattr__(self, "_wrapped_obj", obj)
+        object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4()))
+
+    def __getattr__(self, name):
+        return getattr(self._wrapped_obj, name)
+
+    def __setattr__(self, name, value):
+        setattr(self._wrapped_obj, name, value)
+
+    def __call__(self, *args, **kwargs):
+        return self._wrapped_obj(*args, **kwargs)
+
+    def __str__(self):
+        return str(self._wrapped_obj)
+
+    def __repr__(self):
+        return repr(self._wrapped_obj)
+
+    def __dir__(self):
+        return dir(self._wrapped_obj)
+
+    def unsafe_hard_delete(self):
+        if hasattr(self._wrapped_obj, "unsafe_hard_delete"):
+            self._wrapped_obj.unsafe_hard_delete()
+
+
 class ModelManager:
     _lock = threading.Lock()
 
@@ -546,12 +575,13 @@ def _perform_eviction(self, key, tag, instance, score):
         if key in self._idle_lru:
             del self._idle_lru[key]
 
-        if instance in self._models[tag]:
-            self._models[tag].remove(instance)
-
-        if hasattr(instance, "unsafe_hard_delete"):
-            instance.unsafe_hard_delete()
+        target_id = instance._beam_tracking_id
+        for i, inst in enumerate(self._models[tag]):
+            if inst._beam_tracking_id == target_id:
+                del self._models[tag][i]
+                break
 
+        instance.unsafe_hard_delete()
         del instance
         gc.collect()
         torch.cuda.empty_cache()
@@ -565,7 +595,7 @@ def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost):
         with self._load_lock:
             logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
             isolation_baseline_snap, _, _ = self._monitor.get_stats()
-            instance = loader_func()
+            instance = TrackedModelProxy(loader_func())
             _, peak_during_load, _ = self._monitor.get_stats()
 
             with self._cv:
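
The eviction hunk stops relying on `in` / `list.remove()` (which go through the wrapped object's `__eq__`) and instead matches on the proxy's `_beam_tracking_id`, which every loaded model now carries because `_spawn_new_model` wraps the loader result in `TrackedModelProxy`. Below is a minimal sketch, not part of the patch, showing how the tracking id survives attribute delegation and lets a registry drop the exact proxy instance; `DummyModel` and the `models` dict are hypothetical stand-ins for the real model objects and `ModelManager._models`.

    import uuid


    class TrackedModelProxy:
        def __init__(self, obj):
            object.__setattr__(self, "_wrapped_obj", obj)
            object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4()))

        def __getattr__(self, name):
            # Only invoked when normal lookup fails, so _wrapped_obj and
            # _beam_tracking_id resolve from the proxy itself.
            return getattr(self._wrapped_obj, name)


    class DummyModel:
        def predict(self, x):
            return x * 2


    proxy = TrackedModelProxy(DummyModel())
    models = {"dummy": [proxy]}

    # Attribute access is forwarded to the wrapped model.
    assert proxy.predict(21) == 42

    # Eviction removes by tracking id, not by equality of the wrapped object.
    target_id = proxy._beam_tracking_id
    models["dummy"] = [m for m in models["dummy"] if m._beam_tracking_id != target_id]
    assert models["dummy"] == []

The likely motivation is that equality-based removal is fragile once models are proxied: `__eq__` on the underlying object may be overridden, ambiguous, or expensive (e.g. tensor-like types), whereas comparing UUID strings identifies the instance unambiguously.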