
Commit ddec8fb

CosmosAttnProcessor2_0 revert + CosmosAttnProcessor2_5 changes
1 parent 4b38767 commit ddec8fb

File tree

2 files changed: +88 -71 lines changed


scripts/convert_cosmos_to_diffusers.py

Lines changed: 9 additions & 4 deletions
@@ -104,9 +104,15 @@
     --transformer_type Cosmos-2.5-Transfer-General-2B \
     --transformer_ckpt_path $transformer_ckpt_path \
     --vae_type wan2.1 \
-    --output_path converted/transfer/2b/general/edge \
+    --output_path converted/transfer/2b/general/edge/pipeline \
     --save_pipeline
 
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Transfer-General-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/transfer/2b/general/edge/models
+
 # blur
 transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/blur/ba2f44f2-c726-4fe7-949f-597069d9b91c_ema_bf16.pt
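Note (not part of this commit): the second invocation above writes raw model weights instead of a full pipeline. A minimal sketch of how that output might be consumed afterwards, assuming the CosmosTransformer3DModel class from this repository and the transformer/ subfolder written by the save_pretrained change below; the dtype is an arbitrary illustrative choice.

import torch
from diffusers import CosmosTransformer3DModel

# Hypothetical follow-up: load the model-only output of the second command above.
transformer = CosmosTransformer3DModel.from_pretrained(
    "converted/transfer/2b/general/edge/models/transformer",
    torch_dtype=torch.bfloat16,  # illustrative; match whatever dtype the conversion used
)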
@@ -903,7 +909,7 @@ def get_args():
     controlnet = controlnet.to(dtype=dtype)
 
     if not args.save_pipeline:
-        transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+        transformer.save_pretrained(pathlib.Path(args.output_path) / "transformer", safe_serialization=True, max_shard_size="5GB")
         controlnet.save_pretrained(
             pathlib.Path(args.output_path) / "controlnet", safe_serialization=True, max_shard_size="5GB"
         )
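Assuming both save_pretrained calls above run, the model-only output directory should end up roughly like this (sketch of the expected layout, not verified output):

converted/transfer/2b/general/edge/models/
    transformer/   (config.json plus sharded *.safetensors)
    controlnet/    (config.json plus sharded *.safetensors)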
@@ -943,8 +949,7 @@ def get_args():
         if "Predict" in args.transformer_type:
             save_pipeline_cosmos2_5_predict(args, transformer, vae)
         elif "Transfer" in args.transformer_type:
-            assert controlnet is not None
-            save_pipeline_cosmos2_5_transfer(args, transformer, controlnet, vae)
+            save_pipeline_cosmos2_5_transfer(args, transformer, None, vae)
         else:
             raise AssertionError(f"{args.transformer_type} not supported")
     else:

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 79 additions & 67 deletions
@@ -17,12 +17,12 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
 from ...utils import is_torchvision_available
 from ..attention import FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention
 from ..embeddings import Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
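A hedged aside on the newly imported dispatch_attention_fn: judging from how the call sites below use it (transpose before the call, no transpose after), it takes and returns tensors in (batch, seq_len, heads, head_dim) layout and, with the default backend, should behave like torch's scaled_dot_product_attention. A small equivalence sketch under that assumption:

import torch
import torch.nn.functional as F
from diffusers.models.attention_dispatch import dispatch_attention_fn

# (B, S, H, D) layout, as used by the call sites in this diff.
q = torch.randn(2, 128, 16, 64)
k = torch.randn(2, 128, 16, 64)
v = torch.randn(2, 128, 16, 64)

out_dispatch = dispatch_attention_fn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=0.0, is_causal=False
).transpose(1, 2)
torch.testing.assert_close(out_dispatch, out_sdpa)  # expected to match under the default SDPA backend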
@@ -152,10 +152,10 @@ def forward(
 
 class CosmosAttnProcessor2_0:
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
+        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
             raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
 
-    def compute_attn(
+    def __call__(
         self,
         attn: Attention,
         hidden_states: torch.Tensor,
@@ -199,70 +199,26 @@ def compute_attn(
         value = value.repeat_interleave(query_idx // value_idx, dim=3)
 
         # 5. Attention
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
-        return hidden_states
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        hidden_states = self.compute_attn(
-            attn=attn,
-            hidden_states=hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            image_rotary_emb=image_rotary_emb,
+        hidden_states = dispatch_attention_fn(
+            query.transpose(1, 2),
+            key.transpose(1, 2),
+            value.transpose(1, 2),
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
         )
+        hidden_states = hidden_states.flatten(2, 3).type_as(query)
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
 
         return hidden_states
 
 
-class CosmosAttnProcessor2_5(CosmosAttnProcessor2_0):
+class CosmosAttnProcessor2_5:
     def __init__(self):
         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
             raise ImportError("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer.")
 
-    def compute_attn_i2v(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        img_context=None,
-        attention_mask=None,
-    ):
-        q_img = attn.q_img(hidden_states)
-        k_img = attn.k_img(img_context)
-        v_img = attn.v_img(img_context)
-
-        batch_size = hidden_states.shape[0]
-
-        dim_head = attn.out_dim // attn.heads
-        q_img = q_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
-        k_img = k_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
-        v_img = v_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
-
-        q_img = attn.q_img_norm(q_img)
-        k_img = attn.k_img_norm(k_img)
-
-        q_img_idx = q_img.size(3)
-        k_img_idx = k_img.size(3)
-        v_img_idx = v_img.size(3)
-        k_img = k_img.repeat_interleave(q_img_idx // k_img_idx, dim=3)
-        v_img = v_img.repeat_interleave(q_img_idx // v_img_idx, dim=3)
-        img_out = torch.nn.functional.scaled_dot_product_attention(
-            q_img, k_img, v_img, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        img_out = img_out.transpose(1, 2).flatten(2, 3).type_as(q_img)
-        return img_out
-
     def __call__(
         self,
         attn: Attention,
@@ -277,21 +233,77 @@ def __call__(
         text_context, img_context = encoder_hidden_states if encoder_hidden_states else (None, None)
         text_mask, img_mask = attention_mask if attention_mask else (None, None)
 
-        attn_out = self.compute_attn(
-            attn=attn,
-            hidden_states=hidden_states,
-            encoder_hidden_states=text_context,
-            attention_mask=text_mask,
-            image_rotary_emb=image_rotary_emb,
+        if text_context is None:
+            text_context = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(text_context)
+        value = attn.to_v(text_context)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
+            key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
+
+        if torch.onnx.is_in_onnx_export():
+            query_idx = torch.tensor(query.size(3), device=query.device)
+            key_idx = torch.tensor(key.size(3), device=key.device)
+            value_idx = torch.tensor(value.size(3), device=value.device)
+        else:
+            query_idx = query.size(3)
+            key_idx = key.size(3)
+            value_idx = value.size(3)
+        key = key.repeat_interleave(query_idx // key_idx, dim=3)
+        value = value.repeat_interleave(query_idx // value_idx, dim=3)
+
+        attn_out = dispatch_attention_fn(
+            query.transpose(1, 2),
+            key.transpose(1, 2),
+            value.transpose(1, 2),
+            attn_mask=text_mask,
+            dropout_p=0.0,
+            is_causal=False,
         )
+        attn_out = attn_out.flatten(2, 3).type_as(query)
 
         if img_context is not None:
-            img_out = self.compute_attn_i2v(
-                attn=attn,
-                hidden_states=hidden_states,
-                img_context=img_context,
-                attention_mask=img_mask,
+            q_img = attn.q_img(hidden_states)
+            k_img = attn.k_img(img_context)
+            v_img = attn.v_img(img_context)
+
+            batch_size = hidden_states.shape[0]
+            dim_head = attn.out_dim // attn.heads
+
+            q_img = q_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
+            k_img = k_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
+            v_img = v_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
+
+            q_img = attn.q_img_norm(q_img)
+            k_img = attn.k_img_norm(k_img)
+
+            q_img_idx = q_img.size(3)
+            k_img_idx = k_img.size(3)
+            v_img_idx = v_img.size(3)
+            k_img = k_img.repeat_interleave(q_img_idx // k_img_idx, dim=3)
+            v_img = v_img.repeat_interleave(q_img_idx // v_img_idx, dim=3)
+
+            img_out = dispatch_attention_fn(
+                q_img.transpose(1, 2),
+                k_img.transpose(1, 2),
+                v_img.transpose(1, 2),
+                attn_mask=img_mask,
+                dropout_p=0.0,
+                is_causal=False,
             )
+            img_out = img_out.flatten(2, 3).type_as(q_img)
             hidden_states = attn_out + img_out
         else:
             hidden_states = attn_out
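To make the new __call__ layout concrete, a small self-contained shape sketch of the text/self-attention branch above (illustrative only: plain nn.Linear stands in for the Attention module's to_q/to_k/to_v projections, and all sizes are made up):

import torch
import torch.nn as nn
from diffusers.models.attention_dispatch import dispatch_attention_fn

B, S, T, heads, dim_head = 2, 64, 32, 8, 32
inner = heads * dim_head
to_q, to_k, to_v = nn.Linear(inner, inner), nn.Linear(inner, inner), nn.Linear(inner, inner)

hidden_states = torch.randn(B, S, inner)  # latent tokens
text_context = torch.randn(B, T, inner)   # text tokens; __call__ falls back to hidden_states when None

query = to_q(hidden_states).unflatten(2, (heads, -1)).transpose(1, 2)  # (B, H, S, D)
key = to_k(text_context).unflatten(2, (heads, -1)).transpose(1, 2)     # (B, H, T, D)
value = to_v(text_context).unflatten(2, (heads, -1)).transpose(1, 2)   # (B, H, T, D)

attn_out = dispatch_attention_fn(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
    attn_mask=None, dropout_p=0.0, is_causal=False,
)                                   # (B, S, H, D)
attn_out = attn_out.flatten(2, 3)   # (B, S, H*D), which attn.to_out then projects
print(attn_out.shape)               # torch.Size([2, 64, 256])

The img_context branch mirrors this with the q_img/k_img/v_img projections and adds its output to attn_out.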
