Skip to content

Commit 7e97e43

Browse files
committed
update
1 parent d7b9e42 commit 7e97e43

File tree

2 files changed

+42
-32
lines changed

2 files changed

+42
-32
lines changed

src/diffusers/models/embeddings.py

Lines changed: 10 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -1238,37 +1238,6 @@ def apply_1d_rope(tokens, pos, cos, sin):
12381238
return x
12391239

12401240

1241-
class FluxPosEmbed(nn.Module):
    """Build rotary position-embedding cos/sin tables, one axis of `ids` at a time.

    Args:
        theta: Stored on the module and forwarded as `theta` to
            `get_1d_rotary_pos_embed`.
        axes_dim: Per-axis dimensions; entry `i` is passed as the first argument of
            `get_1d_rotary_pos_embed` for column `i` of `ids`.
    """

    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return `(freqs_cos, freqs_sin)` for the position ids in `ids`.

        Each column `ids[:, i]` is embedded separately and the per-axis results are
        concatenated along the last dimension.

        NOTE(review): the return annotation says `torch.Tensor`, but the method returns
        a 2-tuple of tensors.
        """
        axis_count = ids.shape[-1]
        positions = ids.float()

        # mps and npu devices get float32 frequencies, every other device float64
        # (presumably because those backends lack float64 support - TODO confirm).
        low_precision_device = ids.device.type in ("mps", "npu")
        freqs_dtype = torch.float32 if low_precision_device else torch.float64

        cos_parts, sin_parts = [], []
        for axis in range(axis_count):
            axis_cos, axis_sin = get_1d_rotary_pos_embed(
                self.axes_dim[axis],
                positions[:, axis],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_parts.append(axis_cos)
            sin_parts.append(axis_sin)

        freqs_cos = torch.cat(cos_parts, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_parts, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
1270-
1271-
12721241
class TimestepEmbedding(nn.Module):
12731242
def __init__(
12741243
self,
@@ -2619,3 +2588,13 @@ def forward(self, image_embeds: List[torch.Tensor]):
26192588
projected_image_embeds.append(image_embed)
26202589

26212590
return projected_image_embeds
2591+
2592+
2593+
class FluxPosEmbed(nn.Module):
    """Deprecated import location for `FluxPosEmbed`.

    Constructing this class emits a deprecation notice and returns an instance of the
    `FluxPosEmbed` class defined in `.transformers.transformer_flux`, built with the
    same arguments.
    """

    def __new__(cls, *args, **kwargs):
        message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
        deprecate("FluxPosEmbed", "1.0.0", message)

        # Imported at call time rather than at module import time; an alias keeps the
        # relocated class visibly distinct from this shim of the same name.
        from .transformers.transformer_flux import FluxPosEmbed as _RelocatedFluxPosEmbed

        return _RelocatedFluxPosEmbed(*args, **kwargs)

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 32 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,8 @@
3030
from ..embeddings import (
3131
CombinedTimestepGuidanceTextProjEmbeddings,
3232
CombinedTimestepTextProjEmbeddings,
33-
FluxPosEmbed,
3433
apply_rotary_emb,
34+
get_1d_rotary_pos_embed,
3535
)
3636
from ..modeling_outputs import Transformer2DModelOutput
3737
from ..modeling_utils import ModelMixin
@@ -510,6 +510,37 @@ def forward(
510510
return encoder_hidden_states, hidden_states
511511

512512

513+
class FluxPosEmbed(nn.Module):
    """Build rotary position-embedding cos/sin tables, one axis of `ids` at a time.

    Args:
        theta: Stored on the module and forwarded as `theta` to
            `get_1d_rotary_pos_embed`.
        axes_dim: Per-axis dimensions; entry `i` is passed as the first argument of
            `get_1d_rotary_pos_embed` for column `i` of `ids`.
    """

    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return `(freqs_cos, freqs_sin)` for the position ids in `ids`.

        Each column `ids[:, i]` is embedded separately and the per-axis results are
        concatenated along the last dimension.

        NOTE(review): the return annotation says `torch.Tensor`, but the method returns
        a 2-tuple of tensors.
        """
        axis_count = ids.shape[-1]
        positions = ids.float()

        # mps and npu devices get float32 frequencies, every other device float64
        # (presumably because those backends lack float64 support - TODO confirm).
        low_precision_device = ids.device.type in ("mps", "npu")
        freqs_dtype = torch.float32 if low_precision_device else torch.float64

        cos_parts, sin_parts = [], []
        for axis in range(axis_count):
            axis_cos, axis_sin = get_1d_rotary_pos_embed(
                self.axes_dim[axis],
                positions[:, axis],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_parts.append(axis_cos)
            sin_parts.append(axis_sin)

        freqs_cos = torch.cat(cos_parts, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_parts, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
542+
543+
513544
class FluxTransformer2DModel(
514545
ModelMixin,
515546
ConfigMixin,

0 commit comments

Comments
 (0)