
Commit ca582bc: Merge branch 'main' into stable
2 parents: 10606d7 + 100c06c

File tree: nodes.py, ops.py
2 files changed, +16 -2 lines


nodes.py

Lines changed: 15 additions & 1 deletion
@@ -76,16 +76,29 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
         # TODO: Find another way to not unload after patches
         return super().unpatch_model(device_to=device_to, unpatch_weights=unpatch_weights)
 
+
+    def pin_weight_to_device(self, key):
+        op_key = key.rsplit('.', 1)[0]
+        if self.named_modules_to_munmap is not None and op_key in self.named_modules_to_munmap:
+            # TODO: possible to OOM, find better way to detach
+            self.named_modules_to_munmap[op_key].to(self.load_device).to(self.offload_device)
+            del self.named_modules_to_munmap[op_key]
+        super().pin_weight_to_device(key)
+
     mmap_released = False
+
     def load(self, *args, force_patch_weights=False, **kwargs):
+        if not self.mmap_released:
+            self.named_modules_to_munmap = dict(self.model.named_modules())
+
         # always call `patch_weight_to_device` even for lowvram
         super().load(*args, force_patch_weights=True, **kwargs)
 
         # make sure nothing stays linked to mmap after first load
         if not self.mmap_released:
             linked = []
             if kwargs.get("lowvram_model_memory", 0) > 0:
-                for n, m in self.model.named_modules():
+                for n, m in self.named_modules_to_munmap.items():
                     if hasattr(m, "weight"):
                         device = getattr(m.weight, "device", None)
                         if device == self.offload_device:
@@ -102,6 +115,7 @@ def load(self, *args, force_patch_weights=False, **kwargs):
                         # TODO: possible to OOM, find better way to detach
                         m.to(self.load_device).to(self.offload_device)
             self.mmap_released = True
+            self.named_modules_to_munmap = None
 
     def clone(self, *args, **kwargs):
         src_cls = self.__class__
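
The nodes.py change keeps a snapshot of the model's submodules in `named_modules_to_munmap` so that weights backed by the memory-mapped checkpoint can be detached either one op at a time in `pin_weight_to_device` or in a single pass after the first `load`. The detach itself is the `.to(self.load_device).to(self.offload_device)` round trip: copying to another device and back allocates fresh memory, so the parameters stop aliasing the mmap'd file. A minimal standalone sketch of that idea, with an illustrative helper name and device values that are not part of the patch:

import torch

def release_from_mmap(module: torch.nn.Module, load_device, offload_device):
    # The round trip forces a real copy of every parameter, so nothing in
    # the module keeps the original mmap'd file pages alive afterwards.
    # Caveat (same as the TODO in the patch): the intermediate copy can OOM.
    module.to(load_device)
    module.to(offload_device)

# e.g. release_from_mmap(model, torch.device("cuda"), torch.device("cpu"))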

ops.py

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
         # Take into account space required for dequantizing the largest tensor
         if self.largest_layer:
             shape = getattr(self.weight, "tensor_shape", self.weight.shape)
-            dtype = self.dequant_dtype or torch.float16
+            dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16
             temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
             destination[prefix + "temp.weight"] = temp
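
The ops.py fix accounts for `dequant_dtype` being the string "target" (apparently a sentinel meaning "dequantize to whatever dtype the destination expects") rather than an actual torch.dtype. The old `or torch.float16` fallback would have passed that string through to `torch.empty`, which only accepts real dtypes, so the size estimate now falls back to float16 in that case as well. A small sketch of the selection logic, with an illustrative helper name:

import torch

def dequant_buffer_dtype(dequant_dtype):
    # "target" is a sentinel string, not a torch.dtype, so it cannot be used
    # to allocate the meta tensor that sizes the dequant scratch buffer.
    if dequant_dtype and dequant_dtype != "target":
        return dequant_dtype
    return torch.float16

assert dequant_buffer_dtype(None) is torch.float16
assert dequant_buffer_dtype("target") is torch.float16
assert dequant_buffer_dtype(torch.bfloat16) is torch.bfloat16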
159159
