
Commit b89b5d1

feat: add batching support in Flux RoPE for metaqueries
1 parent 425a715 commit b89b5d1

File tree

check_rope_batched.py
src/diffusers/models/embeddings.py
src/diffusers/models/transformers/transformer_flux.py

3 files changed: +58 -26 lines


check_rope_batched.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from diffusers.models.embeddings import FluxPosEmbed
+import torch
+
+batch_size = 4
+seq_length = 16
+img_seq_length = 32
+txt_ids = torch.randn(batch_size, seq_length, 3)
+img_ids = torch.randn(batch_size, img_seq_length, 3)
+
+pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 4, 8])
+ids = torch.cat((txt_ids, img_ids), dim=1)
+image_rotary_emb = pos_embed(ids)
+# image_rotary_emb[0].shape=torch.Size([4, 48, 16]), image_rotary_emb[1].shape=torch.Size([4, 48, 16])
+print(f"{image_rotary_emb[0].shape=}, {image_rotary_emb[1].shape=}")

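A quick way to extend the check above (not part of the commit): every batch element should match a batch-of-one call on the same slice of ids, since the batched RoPE computes each element independently.

# Sanity-check sketch: batched output vs. batch-of-one calls on each slice.
from diffusers.models.embeddings import FluxPosEmbed
import torch

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 4, 8])
ids = torch.randn(4, 48, 3)                   # [B, S, 3], as in the script above
cos, sin = pos_embed(ids)                     # [4, 48, 16] each

for b in range(ids.shape[0]):
    cos_b, sin_b = pos_embed(ids[b : b + 1])  # batch of one: [1, 48, 16]
    assert torch.allclose(cos[b], cos_b[0])
    assert torch.allclose(sin[b], sin_b[0])
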
src/diffusers/models/embeddings.py

Lines changed: 27 additions & 12 deletions
@@ -1142,32 +1142,38 @@ def get_1d_rotary_pos_embed(
     """
     assert dim % 2 == 0

-    if isinstance(pos, int):
-        pos = torch.arange(pos)
-    if isinstance(pos, np.ndarray):
-        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+    # Handle both batched [B, S] and un-batched [S] inputs
+    if pos.ndim == 1:
+        pos = pos.unsqueeze(0)  # Add a batch dimension if missing

     theta = theta * ntk_factor
     freqs = (
         1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
-    )  # [D/2]
-    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+    )  # Shape: [D/2]
+
+    # Replace torch.outer with broadcasted multiplication
+    # Old: freqs = torch.outer(pos, freqs)  # Shape: [S, D/2]
+    # New: pos is [B, S], freqs is [D/2]. Unsqueeze pos to [B, S, 1] for broadcasting.
+    freqs = pos.unsqueeze(-1) * freqs  # Shape: [B, S, D/2]
+
     is_npu = freqs.device.type == "npu"
     if is_npu:
         freqs = freqs.float()
+
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        # Use dim=-1 for robust interleaving on the feature dimension
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=-1)  # Shape: [B, S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=-1)  # Shape: [B, S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # Shape: [B, S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # Shape: [B, S, D]
         return freqs_cos, freqs_sin
     else:
         # lumina
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Shape: [B, S, D/2]
         return freqs_cis

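The core change in this hunk is swapping torch.outer for a broadcasted multiply so pos can carry a batch dimension. A standalone sketch (not from the commit) confirming that the broadcasted form reproduces the old outer product for un-batched 1-D positions:

# Equivalence sketch: broadcasted multiply vs. the old torch.outer for 1-D pos.
import torch

dim = 8
pos = torch.arange(16, dtype=torch.float64)                                      # [S]
freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))  # [D/2]

old = torch.outer(pos, freqs)                 # [S, D/2]
new = pos.unsqueeze(0).unsqueeze(-1) * freqs  # pos -> [1, S] -> [1, S, 1], result [1, S, D/2]
assert torch.allclose(old, new[0])
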
@@ -1246,26 +1252,35 @@ def __init__(self, theta: int, axes_dim: List[int]):
         self.axes_dim = axes_dim

     def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        # ids is now expected to be [B, S, n_axes]
         n_axes = ids.shape[-1]
         cos_out = []
         sin_out = []
         pos = ids.float()
         is_mps = ids.device.type == "mps"
         is_npu = ids.device.type == "npu"
         freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
-                pos[:, i],
+                pos[:, :, i],  # Correct slicing for batched input
                 theta=self.theta,
                 repeat_interleave_real=True,
                 use_real=True,
                 freqs_dtype=freqs_dtype,
             )
             cos_out.append(cos)
             sin_out.append(sin)
+
         freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
         freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+
+        # Squeeze the batch dim if the original input was unbatched
+        if ids.ndim == 2:
+            freqs_cos = freqs_cos.squeeze(0)
+            freqs_sin = freqs_sin.squeeze(0)
+
         return freqs_cos, freqs_sin

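With the batched path in place, FluxPosEmbed.forward still assembles the embedding per axis: each axis i contributes cos/sin of shape [B, S, axes_dim[i]] after the interleave, and the dim=-1 concatenation yields [B, S, sum(axes_dim)]. A minimal shape sketch, assuming the same axes_dim=[4, 4, 8] configuration used in check_rope_batched.py:

# Shape sketch: per-axis contributions concatenate to sum(axes_dim) features.
import torch
from diffusers.models.embeddings import FluxPosEmbed

axes_dim = [4, 4, 8]                          # assumption: Flux-style axes split
ids = torch.randn(2, 10, len(axes_dim))       # [B, S, n_axes]
cos, sin = FluxPosEmbed(theta=10000, axes_dim=axes_dim)(ids)
print(cos.shape, sin.shape)                   # torch.Size([2, 10, 16]) twice
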
src/diffusers/models/transformers/transformer_flux.py

Lines changed: 17 additions & 14 deletions
@@ -456,20 +456,23 @@ def forward(
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

-        if txt_ids.ndim == 3:
-            logger.warning(
-                "Passing `txt_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            txt_ids = txt_ids[0]
-        if img_ids.ndim == 3:
-            logger.warning(
-                "Passing `img_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            img_ids = img_ids[0]
-
-        ids = torch.cat((txt_ids, img_ids), dim=0)
+        # if txt_ids.ndim == 3:
+        #     logger.warning(
+        #         "Passing `txt_ids` 3d torch.Tensor is deprecated."
+        #         "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        #     )
+        #     txt_ids = txt_ids[0]
+        # if img_ids.ndim == 3:
+        #     logger.warning(
+        #         "Passing `img_ids` 3d torch.Tensor is deprecated."
+        #         "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        #     )
+        #     img_ids = img_ids[0]
+        if txt_ids.ndim == 2:
+            txt_ids = txt_ids.unsqueeze(0)
+        if img_ids.ndim == 2:
+            img_ids = img_ids.unsqueeze(0)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
         image_rotary_emb = self.pos_embed(ids)

         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:

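The transformer-side change is about id shapes only: legacy 2-D txt_ids/img_ids are promoted to a batch of one, and text and image ids are now concatenated along the sequence dimension (dim=1) instead of dim=0. A shape-only sketch (not from the commit) of that promotion:

# Shape-only sketch of the new id handling before calling pos_embed.
import torch

txt_ids = torch.zeros(16, 3)                  # legacy un-batched [S_txt, 3]
img_ids = torch.zeros(32, 3)                  # legacy un-batched [S_img, 3]

if txt_ids.ndim == 2:
    txt_ids = txt_ids.unsqueeze(0)            # [1, S_txt, 3]
if img_ids.ndim == 2:
    img_ids = img_ids.unsqueeze(0)            # [1, S_img, 3]

ids = torch.cat((txt_ids, img_ids), dim=1)    # [1, 48, 3]
print(ids.shape)                              # torch.Size([1, 48, 3])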