@@ -76,16 +76,29 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
7676 # TODO: Find another way to not unload after patches
7777 return super ().unpatch_model (device_to = device_to , unpatch_weights = unpatch_weights )
7878
79+
80+ def pin_weight_to_device (self , key ):
81+ op_key = key .rsplit ('.' , 1 )[0 ]
82+ if self .named_modules_to_munmap is not None and op_key in self .named_modules_to_munmap :
83+ # TODO: possible to OOM, find better way to detach
84+ self .named_modules_to_munmap [op_key ].to (self .load_device ).to (self .offload_device )
85+ del self .named_modules_to_munmap [op_key ]
86+ super ().pin_weight_to_device (key )
87+
7988 mmap_released = False
89+
8090 def load (self , * args , force_patch_weights = False , ** kwargs ):
91+ if not self .mmap_released :
92+ self .named_modules_to_munmap = dict (self .model .named_modules ())
93+
8194 # always call `patch_weight_to_device` even for lowvram
8295 super ().load (* args , force_patch_weights = True , ** kwargs )
8396
8497 # make sure nothing stays linked to mmap after first load
8598 if not self .mmap_released :
8699 linked = []
87100 if kwargs .get ("lowvram_model_memory" , 0 ) > 0 :
88- for n , m in self .model . named_modules ():
101+ for n , m in self .named_modules_to_munmap . items ():
89102 if hasattr (m , "weight" ):
90103 device = getattr (m .weight , "device" , None )
91104 if device == self .offload_device :
@@ -102,6 +115,7 @@ def load(self, *args, force_patch_weights=False, **kwargs):
102115 # TODO: possible to OOM, find better way to detach
103116 m .to (self .load_device ).to (self .offload_device )
104117 self .mmap_released = True
118+ self .named_modules_to_munmap = None
105119
106120 def clone (self , * args , ** kwargs ):
107121 src_cls = self .__class__
0 commit comments