Skip to content

Commit 6cc6c13

Browse files
committed
fixes
1 parent 666a3d9 commit 6cc6c13

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

src/diffusers/models/embeddings.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,10 @@ def __init__(self, theta: int, axes_dim: List[int]):
12521252
self.axes_dim = axes_dim
12531253

12541254
def forward(self, ids: torch.Tensor) -> torch.Tensor:
1255+
was_unbatched = ids.ndim == 2
1256+
if was_unbatched:
1257+
# Add a batch dimension to standardize processing
1258+
ids = ids.unsqueeze(0)
12551259
# ids is now expected to be [B, S, n_axes]
12561260
n_axes = ids.shape[-1]
12571261
cos_out = []
@@ -1277,7 +1281,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
12771281
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
12781282

12791283
# Squeeze the batch dim if the original input was unbatched
1280-
if ids.ndim == 2:
1284+
if was_unbatched:
12811285
freqs_cos = freqs_cos.squeeze(0)
12821286
freqs_sin = freqs_sin.squeeze(0)
12831287

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -456,23 +456,19 @@ def forward(
456456
)
457457
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
458458

459-
# if txt_ids.ndim == 3:
460-
# logger.warning(
461-
# "Passing `txt_ids` 3d torch.Tensor is deprecated."
462-
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
463-
# )
464-
# txt_ids = txt_ids[0]
465-
# if img_ids.ndim == 3:
466-
# logger.warning(
467-
# "Passing `img_ids` 3d torch.Tensor is deprecated."
468-
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
469-
# )
470-
# img_ids = img_ids[0]
471-
if txt_ids.ndim == 2:
472-
txt_ids = txt_ids.unsqueeze(0)
473-
if img_ids.ndim == 2:
474-
img_ids = img_ids.unsqueeze(0)
475-
ids = torch.cat((txt_ids, img_ids), dim=1)
459+
if txt_ids.ndim == 3:
460+
# logger.warning(
461+
# "Passing `txt_ids` 3d torch.Tensor is deprecated."
462+
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
463+
# )
464+
txt_ids = txt_ids[0]
465+
if img_ids.ndim == 3:
466+
# logger.warning(
467+
# "Passing `img_ids` 3d torch.Tensor is deprecated."
468+
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
469+
# )
470+
img_ids = img_ids[0]
471+
ids = torch.cat((txt_ids, img_ids), dim=0)
476472
image_rotary_emb = self.pos_embed(ids)
477473

478474
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:

0 commit comments

Comments
 (0)