Skip to content

Commit 021ba20

Browse files
Fix issue with parameters on root model object. (#12216)
1 parent b60be02 commit 021ba20

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

comfy/model_patcher.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ def get_key_weight(model, key):
161161

162162
return weight, set_func, convert_func
163163

164+
def key_param_name_to_key(key, param):
165+
if len(key) == 0:
166+
return param
167+
return "{}.{}".format(key, param)
168+
164169
class AutoPatcherEjector:
165170
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
166171
self.model = model
@@ -795,7 +800,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
795800
continue
796801

797802
for param in params:
798-
key = "{}.{}".format(n, param)
803+
key = key_param_name_to_key(n, param)
799804
self.unpin_weight(key)
800805
self.patch_weight_to_device(key, device_to=device_to)
801806
if comfy.model_management.is_device_cuda(device_to):
@@ -811,7 +816,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
811816
n = x[1]
812817
params = x[3]
813818
for param in params:
814-
self.pin_weight_to_device("{}.{}".format(n, param))
819+
self.pin_weight_to_device(key_param_name_to_key(n, param))
815820

816821
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
817822
if lowvram_counter > 0:
@@ -917,7 +922,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
917922
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
918923
move_weight = True
919924
for param in params:
920-
key = "{}.{}".format(n, param)
925+
key = key_param_name_to_key(n, param)
921926
bk = self.backup.get(key, None)
922927
if bk is not None:
923928
if not lowvram_possible:
@@ -968,7 +973,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
968973
logging.debug("freed {}".format(n))
969974

970975
for param in params:
971-
self.pin_weight_to_device("{}.{}".format(n, param))
976+
self.pin_weight_to_device(key_param_name_to_key(n, param))
972977

973978

974979
self.model.model_lowvram = True
@@ -1501,7 +1506,7 @@ def set_dirty(item, dirty):
15011506

15021507
def setup_param(self, m, n, param_key):
15031508
nonlocal num_patches
1504-
key = "{}.{}".format(n, param_key)
1509+
key = key_param_name_to_key(n, param_key)
15051510

15061511
weight_function = []
15071512

@@ -1540,7 +1545,7 @@ def setup_param(self, m, n, param_key):
15401545

15411546
else:
15421547
for param in params:
1543-
key = "{}.{}".format(n, param)
1548+
key = key_param_name_to_key(n, param)
15441549
weight, _, _ = get_key_weight(self.model, key)
15451550
weight.seed_key = key
15461551
set_dirty(weight, dirty)

0 commit comments

Comments
 (0)