@@ -1238,37 +1238,6 @@ def apply_1d_rope(tokens, pos, cos, sin):
12381238 return x
12391239
12401240
class FluxPosEmbed(nn.Module):
    """Multi-axis rotary position embedding used by Flux transformers.

    Adapted from
    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11

    Args:
        theta: Base frequency for the rotary embedding.
        axes_dim: Per-axis embedding dimension; one entry per position axis.
    """

    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return (cos, sin) rotary tables for the position ids.

        `ids` has shape (..., n_axes); axis i is embedded with width
        `self.axes_dim[i]` and the per-axis tables are concatenated on the
        last dimension.
        """
        positions = ids.float()
        # float64 is not supported on the mps/npu backends, so fall back
        # to float32 there; elsewhere compute frequencies in float64.
        if ids.device.type in ("mps", "npu"):
            freqs_dtype = torch.float32
        else:
            freqs_dtype = torch.float64
        cos_parts, sin_parts = [], []
        for axis in range(ids.shape[-1]):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[axis],
                positions[:, axis],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_parts.append(cos)
            sin_parts.append(sin)
        freqs_cos = torch.cat(cos_parts, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_parts, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
1271-
12721241class TimestepEmbedding (nn .Module ):
12731242 def __init__ (
12741243 self ,
@@ -2619,3 +2588,13 @@ def forward(self, image_embeds: List[torch.Tensor]):
26192588 projected_image_embeds .append (image_embed )
26202589
26212590 return projected_image_embeds
2591+
2592+
class FluxPosEmbed(nn.Module):
    """Deprecated re-export shim: the implementation moved to
    `diffusers.models.transformers.transformer_flux`.

    Instantiating this class emits a deprecation warning and returns an
    instance of the relocated class instead.
    """

    def __new__(cls, *args, **kwargs):
        # Warn on construction, then hand off to the real class.
        deprecate(
            "FluxPosEmbed",
            "1.0.0",
            "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`.",
        )

        # Imported lazily to avoid a circular import at module load time.
        from .transformers.transformer_flux import FluxPosEmbed

        return FluxPosEmbed(*args, **kwargs)
0 commit comments