Skip to content

Commit 6fd463a

Browse files
Fix regression when text encoder loaded directly on GPU. (#11129)
1 parent 43071e3 commit 6fd463a

File tree

2 files changed: +26 additions, -20 deletions

comfy/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
552552

553553
weight_scale_key = f"{prefix}weight_scale"
554554
scale = state_dict.pop(weight_scale_key, None)
555+
if scale is not None:
556+
scale = scale.to(device)
555557
layout_params = {
556558
'scale': scale,
557559
'orig_dtype': MixedPrecisionOps._compute_dtype,

comfy/sd.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
9898

9999

100100
class CLIP:
101-
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
101+
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
102102
if no_init:
103103
return
104104
params = target.params.copy()
@@ -129,6 +129,27 @@ def __init__(self, target=None, embedding_directory=None, no_init=False, tokeniz
129129
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
130130
self.patcher.is_clip = True
131131
self.apply_hooks_to_conds = None
132+
if len(state_dict) > 0:
133+
if isinstance(state_dict, list):
134+
for c in state_dict:
135+
m, u = self.load_sd(c)
136+
if len(m) > 0:
137+
logging.warning("clip missing: {}".format(m))
138+
139+
if len(u) > 0:
140+
logging.debug("clip unexpected: {}".format(u))
141+
else:
142+
m, u = self.load_sd(state_dict, full_model=True)
143+
if len(m) > 0:
144+
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
145+
if len(m_filter) > 0:
146+
logging.warning("clip missing: {}".format(m))
147+
else:
148+
logging.debug("clip missing: {}".format(m))
149+
150+
if len(u) > 0:
151+
logging.debug("clip unexpected {}:".format(u))
152+
132153
if params['device'] == load_device:
133154
model_management.load_models_gpu([self.patcher], force_full_load=True)
134155
self.layer_idx = None
@@ -1225,14 +1246,7 @@ class EmptyClass:
12251246
parameters += comfy.utils.calculate_parameters(c)
12261247
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
12271248

1228-
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
1229-
for c in clip_data:
1230-
m, u = clip.load_sd(c)
1231-
if len(m) > 0:
1232-
logging.warning("clip missing: {}".format(m))
1233-
1234-
if len(u) > 0:
1235-
logging.debug("clip unexpected: {}".format(u))
1249+
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
12361250
return clip
12371251

12381252
def load_gligen(ckpt_path):
@@ -1362,17 +1376,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
13621376
clip_sd = model_config.process_clip_state_dict(sd)
13631377
if len(clip_sd) > 0:
13641378
parameters = comfy.utils.calculate_parameters(clip_sd)
1365-
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
1366-
m, u = clip.load_sd(clip_sd, full_model=True)
1367-
if len(m) > 0:
1368-
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
1369-
if len(m_filter) > 0:
1370-
logging.warning("clip missing: {}".format(m))
1371-
else:
1372-
logging.debug("clip missing: {}".format(m))
1373-
1374-
if len(u) > 0:
1375-
logging.debug("clip unexpected {}:".format(u))
1379+
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
13761380
else:
13771381
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
13781382

0 commit comments

Comments (0)