Commit ea17add

Fix case where text encoders were running on the CPU instead of the GPU. (#11095)
1 parent ecdc869 commit ea17add

File tree

comfy/sd.py
comfy/sd1_clip.py

2 files changed: +10 -1 lines changed

comfy/sd.py

Lines changed: 2 additions & 0 deletions
@@ -193,6 +193,7 @@ def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
+        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         all_hooks.reset()
         self.patcher.patch_hooks(None)
         if show_pbar:
@@ -240,6 +241,7 @@ def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
+        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         o = self.cond_stage_model.encode_token_weights(tokens)
         cond, pooled = o[:2]
         if return_dict:
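For orientation, here is a minimal, self-contained sketch of the caller-side pattern the comfy/sd.py change follows: right after load_model(), the patcher's load_device is handed to the text encoder as an "execution_device" clip option, so encoding runs on that device instead of wherever the weights happen to sit. The Patcher, CondStageModel, and CLIPWrapper classes below are hypothetical stand-ins, not the real ComfyUI classes; only the option hand-off mirrors the diff above.

# Hypothetical stand-ins for the objects touched in comfy/sd.py; only the
# "execution_device" option hand-off mirrors the actual diff.
import torch

class Patcher:
    def __init__(self, load_device):
        self.load_device = load_device  # device the encoder is loaded onto (e.g. cuda:0)

class CondStageModel:
    def __init__(self):
        self.options = {}

    def set_clip_options(self, options):
        self.options.update(options)  # merge options, like SDClipModel.set_clip_options

class CLIPWrapper:
    def __init__(self, cond_stage_model, patcher):
        self.cond_stage_model = cond_stage_model
        self.patcher = patcher

    def load_model(self):
        pass  # placeholder for the real load/offload management

    def encode_from_tokens(self, tokens):
        self.load_model()
        # the fix: tell the encoder which device to execute on
        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
        return tokens  # the real method would call encode_token_weights(tokens)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clip = CLIPWrapper(CondStageModel(), Patcher(device))
clip.encode_from_tokens([[("hello", 1.0)]])
print(clip.cond_stage_model.options["execution_device"])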

comfy/sd1_clip.py

Lines changed: 8 additions & 1 deletion
@@ -147,6 +147,7 @@ def __init__(self, device="cpu", max_length=77,
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
         self.return_attention_masks = return_attention_masks
+        self.execution_device = None
 
         if layer == "hidden":
             assert layer_idx is not None
@@ -163,6 +164,7 @@ def freeze(self):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+        self.execution_device = options.get("execution_device", self.execution_device)
         if isinstance(self.layer, list) or self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
@@ -175,6 +177,7 @@ def reset_clip_options(self):
         self.layer = self.options_default[0]
         self.layer_idx = self.options_default[1]
         self.return_projected_pooled = self.options_default[2]
+        self.execution_device = None
 
     def process_tokens(self, tokens, device):
         end_token = self.special_tokens.get("end", None)
@@ -258,7 +261,11 @@ def process_tokens(self, tokens, device):
         return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
 
     def forward(self, tokens):
-        device = self.transformer.get_input_embeddings().weight.device
+        if self.execution_device is None:
+            device = self.transformer.get_input_embeddings().weight.device
+        else:
+            device = self.execution_device
+
         embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
 
         attention_mask_model = None
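Taken together, the comfy/sd1_clip.py changes add an optional per-call device override to the encoder itself: set_clip_options can store an execution_device, reset_clip_options clears it, and forward only falls back to the embedding weights' device (which may be a CPU offload location) when no override was given. Below is a small self-contained sketch of that fallback pattern; the tiny nn.Embedding "encoder" is a hypothetical stand-in for the real transformer, not ComfyUI code.

# Minimal sketch of the device-override pattern, assuming PyTorch is installed.
# TinyTextEncoder is a stand-in; only the fallback logic mirrors the diff above.
import torch
import torch.nn as nn

class TinyTextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(16, 4)  # pretend token embeddings (may live on CPU)
        self.execution_device = None           # optional override, as added by the commit

    def set_clip_options(self, options):
        self.execution_device = options.get("execution_device", self.execution_device)

    def reset_clip_options(self):
        self.execution_device = None

    def forward(self, tokens):
        # Fall back to the embedding weights' device only when no override was given.
        if self.execution_device is None:
            device = self.embeddings.weight.device
        else:
            device = self.execution_device
        token_ids = torch.tensor(tokens, device=device, dtype=torch.long)
        return self.embeddings.weight.to(device)[token_ids]

enc = TinyTextEncoder()  # weights stay wherever the module was created (CPU here)
enc.set_clip_options({"execution_device": "cuda:0" if torch.cuda.is_available() else "cpu"})
out = enc([1, 2, 3])
print(out.device)  # computation happens on the override device, not where the weights sit
enc.reset_clip_options()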
