Skip to content

Commit 100c06c

Browse files
authored
nodes: mmap detach weights before they are pinned (#355)
Comfy core recently introduced a feature where weights may be pinned when loading, particularly for the case of offloading. Intercept this, and immediately detach each weight before the pinning. This avoids a crash that at least some users are experiencing. Use a little dict on the modules to keep track of what's already done, and when the catch-all detacher loop comes through, use this dict (which has already-done modules removed) as the iterator basis.
1 parent be2a083 commit 100c06c

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

nodes.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,29 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
7676
# TODO: Find another way to not unload after patches
7777
return super().unpatch_model(device_to=device_to, unpatch_weights=unpatch_weights)
7878

79+
80+
def pin_weight_to_device(self, key):
81+
op_key = key.rsplit('.', 1)[0]
82+
if self.named_modules_to_munmap is not None and op_key in self.named_modules_to_munmap:
83+
# TODO: possible to OOM, find better way to detach
84+
self.named_modules_to_munmap[op_key].to(self.load_device).to(self.offload_device)
85+
del self.named_modules_to_munmap[op_key]
86+
super().pin_weight_to_device(key)
87+
7988
mmap_released = False
89+
8090
def load(self, *args, force_patch_weights=False, **kwargs):
91+
if not self.mmap_released:
92+
self.named_modules_to_munmap = dict(self.model.named_modules())
93+
8194
# always call `patch_weight_to_device` even for lowvram
8295
super().load(*args, force_patch_weights=True, **kwargs)
8396

8497
# make sure nothing stays linked to mmap after first load
8598
if not self.mmap_released:
8699
linked = []
87100
if kwargs.get("lowvram_model_memory", 0) > 0:
88-
for n, m in self.model.named_modules():
101+
for n, m in self.named_modules_to_munmap.items():
89102
if hasattr(m, "weight"):
90103
device = getattr(m.weight, "device", None)
91104
if device == self.offload_device:
@@ -102,6 +115,7 @@ def load(self, *args, force_patch_weights=False, **kwargs):
102115
# TODO: possible to OOM, find better way to detach
103116
m.to(self.load_device).to(self.offload_device)
104117
self.mmap_released = True
118+
self.named_modules_to_munmap = None
105119

106120
def clone(self, *args, **kwargs):
107121
src_cls = self.__class__

0 commit comments

Comments
 (0)