
Commit 26cfacb

Merge branch 'comfyanonymous:master' into gempoll
2 parents: 464a385 + a4ec54a

15 files changed: 232 additions, 91 deletions


comfy/clip_model.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device

class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device, operations):
        super().__init__()

        self.heads = heads
        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None, optimized_attention=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        out = optimized_attention(q, k, v, self.heads, mask)
        return self.out_proj(out)

ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
               "gelu": torch.nn.functional.gelu,
}

class CLIPMLP(torch.nn.Module):
    def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
        super().__init__()
        self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
        self.activation = ACTIVATIONS[activation]
        self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x

class CLIPLayer(torch.nn.Module):
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
        super().__init__()
        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)

    def forward(self, x, mask=None, optimized_attention=None):
        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
        x += self.mlp(self.layer_norm2(x))
        return x


class CLIPEncoder(torch.nn.Module):
    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
        super().__init__()
        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])

    def forward(self, x, mask=None, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, mask=True)
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask

        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output

        intermediate = None
        for i, l in enumerate(self.layers):
            x = l(x, mask, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate

class CLIPEmbeddings(torch.nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight


class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]

        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
        x = self.embeddings(input_tokens)
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)

        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
        return x, i, pooled_output

class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        return self.text_model(*args, **kwargs)
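
Note: a minimal usage sketch of the new module, not part of the commit. The config sizes mirror a typical CLIP ViT-L/14 text encoder, and the ops class below is a hypothetical stand-in for the operations object ComfyUI normally supplies (its Linear/LayerNorm wrappers):

import torch
from comfy.clip_model import CLIPTextModel

class ops:
    # Stand-in for the operations argument; plain torch layers are enough for a smoke test.
    Linear = torch.nn.Linear
    LayerNorm = torch.nn.LayerNorm

config = {                    # illustrative ViT-L/14 text-encoder sizes
    "num_hidden_layers": 12,
    "hidden_size": 768,
    "num_attention_heads": 12,
    "intermediate_size": 3072,
    "hidden_act": "quick_gelu",
}

model = CLIPTextModel(config, dtype=torch.float32, device="cpu", operations=ops)
tokens = torch.randint(0, 49408, (1, 77))           # one 77-token prompt
last_hidden, intermediate, pooled = model(tokens, intermediate_output=-2)
print(last_hidden.shape, pooled.shape)              # [1, 77, 768] and [1, 768]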

comfy/clip_vision.py

Lines changed: 2 additions & 2 deletions
@@ -54,10 +54,10 @@ def encode_image(self, image):
             t = outputs[k]
             if t is not None:
                 if k == 'hidden_states':
-                    outputs["penultimate_hidden_states"] = t[-2].cpu()
+                    outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
                     outputs["hidden_states"] = None
                 else:
-                    outputs[k] = t.cpu()
+                    outputs[k] = t.to(comfy.model_management.intermediate_device())

         return outputs

comfy/ldm/modules/attention.py

Lines changed: 19 additions & 4 deletions
@@ -112,10 +112,13 @@ def attention_basic(q, k, v, heads, mask=None):
     del q, k

     if exists(mask):
-        mask = rearrange(mask, 'b ... -> b (...)')
-        max_neg_value = -torch.finfo(sim.dtype).max
-        mask = repeat(mask, 'b j -> (b h) () j', h=h)
-        sim.masked_fill_(~mask, max_neg_value)
+        if mask.dtype == torch.bool:
+            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+        else:
+            sim += mask

     # attention, what we cannot get enough of
     sim = sim.softmax(dim=-1)
@@ -340,6 +343,18 @@ def attention_pytorch(q, k, v, heads, mask=None):
 if model_management.pytorch_attention_enabled():
     optimized_attention_masked = attention_pytorch

+def optimized_attention_for_device(device, mask=False):
+    if device == torch.device("cpu"): #TODO
+        if model_management.pytorch_attention_enabled():
+            return attention_pytorch
+        else:
+            return attention_basic
+    if mask:
+        return optimized_attention_masked
+
+    return optimized_attention
+
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
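
Note: a short sketch of how the new helper is intended to be used, mirroring what CLIPEncoder does in comfy/clip_model.py; the shapes and the float mask below are illustrative, not part of the commit:

import torch
from comfy.ldm.modules.attention import optimized_attention_for_device

heads = 8
q = k = v = torch.randn(2, 77, 512)        # (batch, tokens, heads * dim_head)

# Resolve the attention kernel once per device, then reuse it for every call.
attn = optimized_attention_for_device(q.device, mask=True)

# Additive float mask, matching the new "sim += mask" branch in attention_basic:
# 0.0 keeps a position, -inf removes it. Boolean masks still take the old masked_fill_ path.
mask = torch.zeros(77, 77)
mask[:, 64:] = float("-inf")               # e.g. ignore everything past token 64

out = attn(q, k, v, heads, mask)           # -> (2, 77, 512)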

comfy/model_management.py

Lines changed: 6 additions & 0 deletions
@@ -508,6 +508,12 @@ def text_encoder_dtype(device=None):
     else:
         return torch.float32

+def intermediate_device():
+    if args.gpu_only:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def vae_device():
     return get_torch_device()
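
Note: intermediate_device() is what the rest of this commit uses in place of hard-coded .cpu() calls for sampler, VAE and CLIP vision outputs; a minimal illustrative sketch:

import torch
import comfy.model_management as model_management

device = model_management.intermediate_device()
# torch.device("cpu") by default; the main torch device when ComfyUI is started with --gpu-only.
latent = torch.zeros(1, 4, 64, 64).to(device)      # how sampler/VAE results are now placed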

comfy/model_sampling.py

Lines changed: 10 additions & 3 deletions
@@ -22,10 +22,17 @@ def calculate_denoised(self, sigma, model_output, model_input):
 class ModelSamplingDiscrete(torch.nn.Module):
     def __init__(self, model_config=None):
         super().__init__()
-        beta_schedule = "linear"
+
         if model_config is not None:
-            beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
-        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+            sampling_settings = model_config.sampling_settings
+        else:
+            sampling_settings = {}
+
+        beta_schedule = sampling_settings.get("beta_schedule", "linear")
+        linear_start = sampling_settings.get("linear_start", 0.00085)
+        linear_end = sampling_settings.get("linear_end", 0.012)
+
+        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
         self.sigma_data = 1.0

     def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
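
Note: with this change a model config can override the schedule endpoints as well as the schedule name. A small illustrative sketch (DummyConfig and its values are hypothetical; only the sampling_settings attribute is read by the constructor):

from comfy.model_sampling import ModelSamplingDiscrete

class DummyConfig:
    # The only attribute the constructor reads; the values shown match the defaults above.
    sampling_settings = {"beta_schedule": "linear",
                         "linear_start": 0.00085,
                         "linear_end": 0.012}

ms = ModelSamplingDiscrete(DummyConfig())    # picks up beta_schedule / linear_start / linear_end
ms_default = ModelSamplingDiscrete()         # model_config=None falls back to the same defaults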

comfy/sample.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(comfy.model_management.intermediate_device())

     cleanup_additional_models(models)
     cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
@@ -111,7 +111,7 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
     sigmas = sigmas.to(model.load_device)

     samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(comfy.model_management.intermediate_device())
     cleanup_additional_models(models)
     cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples

comfy/sd.py

Lines changed: 12 additions & 11 deletions
@@ -190,6 +190,7 @@ def __init__(self, sd=None, device=None, config=None):
         offload_device = model_management.vae_offload_device()
         self.vae_dtype = model_management.vae_dtype()
         self.first_stage_model.to(self.vae_dtype)
+        self.output_device = model_management.intermediate_device()

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)

@@ -201,9 +202,9 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):

         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar))
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output

@@ -214,9 +215,9 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         pbar = comfy.utils.ProgressBar(steps)

         encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
-        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
         samples /= 3.0
         return samples

@@ -228,15 +229,15 @@ def decode(self, samples_in):
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)

-        pixel_samples = pixel_samples.cpu().movedim(1,-1)
+        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples

     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@@ -252,10 +253,10 @@ def encode(self, pixel_samples):
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
-            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
+            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device)
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()

         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
