@@ -79,13 +79,14 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
7979
def pin_weight_to_device(self, key):
    """Release a module still awaiting munmap, then pin the weight via the parent.

    The lookup name is ``key`` with its last dot-separated component
    removed (presumably the name of the module that holds the weight).
    While ``self.mmap_released`` is false and that name is present in
    ``self.named_modules_to_munmap``, the stored module is moved to
    ``self.load_device`` and then to ``self.offload_device``, and its
    entry is deleted from the mapping. In every case the call is then
    forwarded to the parent class's ``pin_weight_to_device``.

    Args:
        key: Dotted weight key; only the part before the final ``.`` is
            used for the lookup, the full key is passed to the parent.
    """
    module_name = key.rsplit('.', 1)[0]
    # Short-circuits: the mapping is not consulted once mmap is released.
    still_mapped = (
        not self.mmap_released
        and module_name in self.named_modules_to_munmap
    )
    if still_mapped:
        # TODO: possible to OOM, find better way to detach
        owner = self.named_modules_to_munmap[module_name]
        # Round trip through the load device and back to the offload device;
        # per the TODO above this is the current way of detaching the module.
        owner.to(self.load_device).to(self.offload_device)
        del self.named_modules_to_munmap[module_name]
    super().pin_weight_to_device(key)
8787
# Class-level default: stays False until load() finishes moving the modules
# and sets self.mmap_released = True on the instance.
8888 mmap_released = False
# Class-level default: an empty dict, so the membership test in
# pin_weight_to_device works before any instance-level assignment.
# NOTE(review): this one dict object is shared by every instance that has
# not assigned its own named_modules_to_munmap — confirm nothing adds
# entries to it through the class-level reference.
89+ named_modules_to_munmap = {}
8990
9091 def load (self , * args , force_patch_weights = False , ** kwargs ):
9192 if not self .mmap_released :
@@ -115,7 +116,7 @@ def load(self, *args, force_patch_weights=False, **kwargs):
115116 # TODO: possible to OOM, find better way to detach
116117 m .to (self .load_device ).to (self .offload_device )
117118 self .mmap_released = True
118- self .named_modules_to_munmap = None
119+ self .named_modules_to_munmap = {}
119120
120121 def clone (self , * args , ** kwargs ):
121122 src_cls = self .__class__
@@ -125,6 +126,7 @@ def clone(self, *args, **kwargs):
125126 self .__class__ = src_cls
126127 # GGUF specific clone values below
127128 n .patch_on_device = getattr (self , "patch_on_device" , False )
129+ n .mmap_released = getattr (self , "mmap_released" , False )
128130 if src_cls != GGUFModelPatcher :
129131 n .size = 0 # force recalc
130132 return n
0 commit comments