-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathmodel.py
More file actions
398 lines (318 loc) · 14.2 KB
/
model.py
File metadata and controls
398 lines (318 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import math
from collections import OrderedDict
from typing import Optional
import torch
import os
from safetensors.torch import load_file
from comfy import model_management
from comfy.model_patcher import ModelPatcher
from safetensors.torch import load_model
from torch import nn
from .activations import get_activation
from .utils import patch_device_empty_setter, remove_weights
# Opt-in debug switch: true only when the ELLA_DEBUG environment variable is
# exactly "1", "true" or "True" at import time.
# NOTE(review): no code in this module appears to read this flag.
ELLA_DEBUG = os.getenv("ELLA_DEBUG", "0") in {"1", "true", "True"}
def _count_params(m: torch.nn.Module):
total = sum(p.numel() for p in m.parameters())
trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
return total, trainable
def _size_mb(m: torch.nn.Module):
# estimate the size of parameters in MB
bytes_total = sum(p.numel() * p.element_size() for p in m.parameters())
return bytes_total / (1024 ** 2)
def load_model_lenient(model: torch.nn.Module, path: str):
    """Load a safetensors checkpoint into ``model``, skipping entries that do not fit.

    For every checkpoint entry whose key exists in ``model.state_dict()`` with the same
    shape, the tensor is cast to the model tensor's dtype and moved to its device, then
    loaded. Keys with a different shape, and keys the model does not have, are skipped.
    Summaries of shape-skipped, unknown, dtype-cast and missing keys are printed.

    Note: a key skipped because of a shape mismatch is also counted under
    "missing in ckpt", because it never reaches the loaded dict.

    Args:
        model: module to fill in place (loaded with ``strict=False``).
        path: path to a ``.safetensors`` file.

    Returns:
        The same ``model`` object.
    """
    checkpoint = load_file(path)  # maps key -> Tensor
    expected = model.state_dict()

    accepted = {}
    shape_mismatch = []
    unexpected = []
    converted = []

    for name, tensor in checkpoint.items():
        target = expected.get(name)
        if target is None:
            unexpected.append(name)
            continue
        if target.shape != tensor.shape:
            shape_mismatch.append((name, tuple(tensor.shape), tuple(target.shape)))
            continue
        if target.dtype != tensor.dtype:
            tensor = tensor.to(target.dtype)
            converted.append(name)
        # place the tensor where the model's own tensor lives
        if tensor.device != target.device:
            tensor = tensor.to(target.device)
        accepted[name] = tensor

    if shape_mismatch:
        print(f"[ELLA/load] skipped by shape: {len(shape_mismatch)} (e.g. {shape_mismatch[:3]})")
    if unexpected:
        print(f"[ELLA/load] extra keys in ckpt: {len(unexpected)} (e.g. {unexpected[:5]})")
    if converted:
        print(f"[ELLA/load] dtype casted: {len(converted)} (e.g. {converted[:5]})")

    missing = [name for name in expected if name not in accepted]
    if missing:
        print(f"[ELLA/load] missing in ckpt: {len(missing)} (e.g. {missing[:5]})")

    model.load_state_dict(accepted, strict=False)
    return model
class AdaLayerNorm(nn.Module):
    """LayerNorm whose shift and scale are predicted from a timestep embedding.

    Computes ``norm(x) * (1 + scale) + shift``, where ``(shift, scale)`` is a linear
    projection of ``silu(timestep_embedding)``. The projection is zero-initialised, so a
    freshly constructed layer acts like a plain (non-affine) LayerNorm until weights are
    loaded or trained.
    """

    def __init__(self, embedding_dim: int, time_embedding_dim: Optional[int] = None):
        super().__init__()
        if time_embedding_dim is None:
            time_embedding_dim = embedding_dim
        self.silu = nn.SiLU()
        self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
        # zero init -> shift = scale = 0 at start, i.e. identity modulation
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, timestep_embedding: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` and modulate it with the timestep embedding.

        Args:
            x: input whose first dimension is the batch and whose last dimension is
                ``embedding_dim``.
            timestep_embedding: conditioning with last dimension ``time_embedding_dim``.
                After projection it is reshaped to ``(len(x), 1, 2 * embedding_dim)``, so
                it must supply one embedding per batch element of ``x``.

        Returns:
            A single tensor ``norm(x) * (1 + scale) + shift``; shift/scale broadcast over
            ``x``'s second dimension.
        """
        emb = self.linear(self.silu(timestep_embedding))
        shift, scale = emb.view(len(x), 1, -1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
class SquaredReLU(nn.Module):
    """Element-wise activation ``relu(x) ** 2`` (used inside the Perceiver MLP)."""

    def forward(self, x: torch.Tensor):
        positive = torch.relu(x)
        return positive.square()
class PerceiverAttentionBlock(nn.Module):
    """One Perceiver layer updating a set of latents from input tokens.

    The latents attend (multi-head attention, batch-first) over the concatenation of
    themselves and the input ``x`` along dim 1, then pass through a squared-ReLU MLP.
    Both sub-layers are residual and pre-normalised with timestep-conditioned
    ``AdaLayerNorm``.
    """

    def __init__(self, d_model: int, n_heads: int, time_embedding_dim: Optional[int] = None):
        super().__init__()
        hidden_dim = 4 * d_model
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # Submodule names must stay "c_fc"/"sq_relu"/"c_proj": they are checkpoint keys.
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, hidden_dim)
        mlp_layers["sq_relu"] = SquaredReLU()
        mlp_layers["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(mlp_layers)
        self.ln_1 = AdaLayerNorm(d_model, time_embedding_dim)
        self.ln_2 = AdaLayerNorm(d_model, time_embedding_dim)
        self.ln_ff = AdaLayerNorm(d_model, time_embedding_dim)

    def attention(self, q: torch.Tensor, kv: torch.Tensor):
        """Run multi-head attention with ``kv`` as both keys and values; weights are discarded."""
        attended, _ = self.attn(q, kv, kv, need_weights=False)
        return attended

    def forward(
        self,
        x: torch.Tensor,
        latents: torch.Tensor,
        timestep_embedding: Optional[torch.Tensor] = None,
    ):
        """Return the updated latents.

        NOTE(review): ``timestep_embedding`` defaults to None, but every AdaLayerNorm here
        applies SiLU to it, so a tensor is effectively required.
        """
        t_emb = timestep_embedding
        queries = self.ln_1(latents, t_emb)
        context = torch.cat([queries, self.ln_2(x, t_emb)], dim=1)
        latents = latents + self.attention(q=queries, kv=context)
        ff_input = self.ln_ff(latents, t_emb)
        return latents + self.mlp(ff_input)
class PerceiverResampler(nn.Module):
    """Resample a variable-length token sequence into ``num_latents`` latent vectors.

    A learned latent array (shifted by a projection of the timestep embedding) is
    refined by ``layers`` stacked ``PerceiverAttentionBlock``s that attend over the
    input. Optional ``proj_in`` maps inputs of size ``input_dim`` to ``width``; optional
    ``proj_out`` maps the result to ``output_dim`` followed by LayerNorm.
    """

    def __init__(
        self,
        width: int = 768,
        layers: int = 6,
        heads: int = 8,
        num_latents: int = 64,
        output_dim=None,
        input_dim=None,
        time_embedding_dim: Optional[int] = None,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        init_scale = width**-0.5
        self.latents = nn.Parameter(init_scale * torch.randn(num_latents, width))
        self.time_aware_linear = nn.Linear(time_embedding_dim or width, width, bias=True)
        if input_dim is not None:
            self.proj_in = nn.Linear(input_dim, width)  # type: ignore
        blocks = [PerceiverAttentionBlock(width, heads, time_embedding_dim=time_embedding_dim) for _ in range(layers)]
        self.perceiver_blocks = nn.Sequential(*blocks)
        if output_dim is not None:
            self.proj_out = nn.Sequential(nn.Linear(width, output_dim), nn.LayerNorm(output_dim))  # type: ignore

    def forward(self, x: torch.Tensor, timestep_embedding: torch.Tensor = None):  # type: ignore
        """Return latents of shape ``(len(x), num_latents, width or output_dim)``.

        NOTE(review): ``timestep_embedding`` is passed through SiLU unconditionally, so it
        must be a tensor despite the ``None`` default.
        """
        batch_size = len(x)
        base = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
        latents = base + self.time_aware_linear(torch.nn.functional.silu(timestep_embedding))
        tokens = self.proj_in(x) if self.input_dim is not None else x
        for block in self.perceiver_blocks:
            latents = block(tokens, latents, timestep_embedding=timestep_embedding)
        return self.proj_out(latents) if self.output_dim is not None else latents
class T5TextEmbedder:
    """Hugging Face T5 encoder wrapped in a ComfyUI ``ModelPatcher``.

    The encoder and tokenizer are loaded from ``pretrained_path``; the encoder is cast to
    ``dtype`` (ComfyUI's text-encoder dtype when not given) and managed between the
    text-encoder load and offload devices.
    """

    def __init__(self, pretrained_path="google/flan-t5-xl", max_length=None, dtype=None, legacy=True):
        self.load_device = model_management.text_encoder_device()
        self.offload_device = model_management.text_encoder_offload_device()
        self.dtype = dtype if dtype is not None else model_management.text_encoder_dtype(self.load_device)
        self.output_device = model_management.intermediate_device()
        self.max_length = max_length

        from transformers import T5EncoderModel, T5Tokenizer

        self.model = T5EncoderModel.from_pretrained(pretrained_path).to(self.dtype)  # type: ignore
        patch_device_empty_setter(self.model.__class__)
        self.tokenizer = T5Tokenizer.from_pretrained(pretrained_path, legacy=legacy)
        self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)

    def load_model(self):
        """Ask ComfyUI to load the encoder onto its load device; returns the patcher."""
        model_management.load_model_gpu(self.patcher)
        return self.patcher

    def __call__(self, caption, text_input_ids=None, attention_mask=None, max_length=None, **kwargs):
        """Encode ``caption`` and return the encoder's ``last_hidden_state``.

        Args:
            caption: prompt text passed to the tokenizer after A1111/ComfyUI prompt-weight
                syntax is stripped by ``remove_weights``.
            text_input_ids, attention_mask: pre-tokenised inputs. Both must be given to
                be used; if either is None the caption is tokenised instead.
            max_length: overrides ``self.max_length``. When a length is set, inputs are
                padded/truncated to it; otherwise tokenisation is unpadded.
            **kwargs: accepted and ignored.

        Returns:
            The last hidden state moved to ComfyUI's intermediate device.
        """
        self.load_model()
        model_device = self.model.device
        # remove a1111/comfyui prompt weight, t5 embedder currently does not accept weight
        caption = remove_weights(caption)
        if max_length is None:
            max_length = self.max_length

        if text_input_ids is None or attention_mask is None:
            if max_length is not None:
                text_inputs = self.tokenizer(
                    caption,
                    return_tensors="pt",
                    add_special_tokens=True,
                    max_length=max_length,
                    padding="max_length",
                    truncation=True,
                )
            else:
                text_inputs = self.tokenizer(caption, return_tensors="pt", add_special_tokens=True)
            text_input_ids = text_inputs.input_ids.to(model_device)
            attention_mask = text_inputs.attention_mask.to(model_device)
        else:
            text_input_ids = text_input_ids.to(model_device)
            attention_mask = attention_mask.to(model_device)

        # Ensure the whole model sits on the device of its first parameter.
        self.model.to(model_device)
        # Inference only: without no_grad the autograd graph (all encoder activations)
        # would stay referenced by the returned hidden state.
        with torch.no_grad():
            outputs = self.model(text_input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state.to(self.output_device)
class TimestepEmbedding(nn.Module):
    """Two-layer MLP that projects timestep features to ``time_embed_dim`` (or ``out_dim``).

    Layout: optional ``cond_proj`` added to the input, ``linear_1`` -> activation
    ``act_fn`` -> ``linear_2`` -> optional ``post_act_fn``. ``sample_proj_bias`` toggles
    the bias of both linear layers.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=sample_proj_bias)
        self.cond_proj = None if cond_proj_dim is None else nn.Linear(cond_proj_dim, in_channels, bias=False)
        self.act = get_activation(act_fn)
        final_dim = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, final_dim, bias=sample_proj_bias)
        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Project ``sample`` (plus ``cond_proj(condition)`` when a condition is given).

        NOTE(review): passing ``condition`` requires the module to have been built with
        ``cond_proj_dim``; otherwise ``self.cond_proj`` is None and the call fails.
        """
        if condition is not None:
            sample = sample + self.cond_proj(condition)  # type: ignore
        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)
        hidden = self.linear_2(hidden)
        return hidden if self.post_act is None else self.post_act(hidden)
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Create sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps: 1-D tensor with one value per batch element; values may be fractional.
        embedding_dim: width of each embedding. For an odd width a zero column is
            appended at the end.
        flip_sin_to_cos: if True, order the halves as ``[cos, sin]`` instead of ``[sin, cos]``.
        downscale_freq_shift: subtracted from ``embedding_dim // 2`` in the denominator of
            the frequency exponent.
        scale: multiplier applied to ``timestep * frequency`` before sin/cos.
        max_period: controls the minimum frequency of the embeddings.

    Returns:
        A float32 tensor of shape ``(len(timesteps), embedding_dim)``.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half = embedding_dim // 2
    steps = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * steps / (half - downscale_freq_shift))
    angles = scale * (timesteps[:, None].float() * freqs[None, :])
    sines = torch.sin(angles)
    cosines = torch.cos(angles)
    halves = [cosines, sines] if flip_sin_to_cos else [sines, cosines]
    emb = torch.cat(halves, dim=-1)
    if embedding_dim % 2 == 1:
        # zero-pad the last column so the output width equals embedding_dim
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
class Timesteps(nn.Module):
    """Parameter-free module mapping scalar timesteps to sinusoidal features.

    Thin wrapper over ``get_timestep_embedding`` that fixes the embedding width, the
    sin/cos ordering and the frequency shift; ``scale`` and ``max_period`` keep that
    function's defaults.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        options = dict(
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return get_timestep_embedding(timesteps, self.num_channels, **options)
class ELLAModel(nn.Module):
    """Timestep-aware connector turning T5 encoder states into a fixed set of latents.

    Pipeline: sinusoidal ``Timesteps`` features (``time_channel`` wide) ->
    ``TimestepEmbedding`` MLP -> ``PerceiverResampler`` over the T5 states, which projects
    inputs of size ``input_dim`` to ``width`` and returns ``num_latents`` vectors per
    batch element (no output projection is configured here).
    """

    def __init__(
        self,
        time_channel=320,
        time_embed_dim=768,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        width=768,
        layers=6,
        heads=8,
        num_latents=64,
        input_dim=2048,
    ):
        super().__init__()
        self.position = Timesteps(num_channels=time_channel, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(
            in_channels=time_channel,
            time_embed_dim=time_embed_dim,
            act_fn=act_fn,
            out_dim=out_dim,  # type: ignore
        )
        self.connector = PerceiverResampler(
            width=width,
            layers=layers,
            heads=heads,
            num_latents=num_latents,
            input_dim=input_dim,
            time_embedding_dim=time_embed_dim,
        )

    def forward(self, timesteps: torch.Tensor, t5_embeds: torch.Tensor, **kwargs):
        """Return connector latents for ``t5_embeds`` conditioned on ``timesteps``.

        ``timesteps`` is flattened; it must hold either one value or one value per batch
        element of ``t5_embeds`` so the features can be expanded to the batch size.
        Timestep features are cast to ``t5_embeds``' device and dtype. Extra ``kwargs``
        are ignored.
        """
        time_feat = self.position(timesteps.view(-1)).to(device=t5_embeds.device, dtype=t5_embeds.dtype)
        if time_feat.ndim == 2:
            # (N, C) -> (N, 1, C): a single time token per sample
            time_feat = time_feat.unsqueeze(dim=1)
        time_feat = time_feat.expand(len(t5_embeds), -1, -1)
        return self.connector(t5_embeds, timestep_embedding=self.time_embedding(time_feat))
class ELLA:
    """ComfyUI-managed wrapper around an ``ELLAModel`` loaded from a safetensors file.

    Weights are loaded leniently from ``path`` with ``load_model_lenient``, cast to
    float16, and handed to a ``ModelPatcher`` so ComfyUI can move the model between the
    text-encoder load and offload devices.
    """

    def __init__(self, path: str, **kwargs) -> None:
        self.load_device = model_management.text_encoder_device()
        self.offload_device = model_management.text_encoder_offload_device()
        # ComfyUI's text-encoder dtype; kept as a public attribute. The ELLA weights
        # themselves are always stored in float16 (below).
        self.dtype = model_management.text_encoder_dtype(self.load_device)
        self.output_device = model_management.intermediate_device()
        self.model = ELLAModel()
        load_model_lenient(self.model, path)
        self.model.to(dtype=torch.float16)
        self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)

    def load_model(self):
        """Ask ComfyUI to load the ELLA model onto its load device; returns the patcher."""
        model_management.load_model_gpu(self.patcher)
        return self.patcher

    def __call__(self, timesteps: torch.Tensor, t5_embeds: torch.Tensor, **kwargs):
        """Compute ELLA conditioning and return it on ComfyUI's intermediate device.

        Args:
            timesteps: diffusion timesteps, cast to int64 on the load device.
            t5_embeds: T5 encoder hidden states, cast to the model's weight dtype.
            **kwargs: forwarded to ``ELLAModel.forward``.
        """
        self.load_model()
        # Feed inputs in the dtype the weights actually have. ``self.dtype`` can differ
        # from float16 (fp32/bf16/fp8 text-encoder settings), which previously made the
        # model's Linear layers fail with a dtype mismatch.
        model_dtype = next(self.model.parameters()).dtype
        # NOTE(review): the int64 cast truncates fractional timesteps; confirm intended.
        timesteps = timesteps.to(device=self.load_device, dtype=torch.int64)
        t5_embeds = t5_embeds.to(device=self.load_device, dtype=model_dtype)
        # Inference only: avoid building an autograd graph the returned tensor would keep alive.
        with torch.no_grad():
            cond = self.model(timesteps, t5_embeds, **kwargs)
        return cond.to(self.output_device)