@@ -1142,32 +1142,38 @@ def get_1d_rotary_pos_embed(
11421142 """
11431143 assert dim % 2 == 0
11441144
1145- if isinstance (pos , int ):
1146- pos = torch .arange (pos )
1147- if isinstance (pos , np .ndarray ):
1148- pos = torch .from_numpy (pos ) # type: ignore # [S]
1145+ # Handle both batched [B, S] and un-batched [S] inputs
1146+ if pos .ndim == 1 :
1147+ pos = pos .unsqueeze (0 ) # Add a batch dimension if missing
11491148
11501149 theta = theta * ntk_factor
11511150 freqs = (
11521151 1.0 / (theta ** (torch .arange (0 , dim , 2 , dtype = freqs_dtype , device = pos .device ) / dim )) / linear_factor
1153- ) # [D/2]
1154- freqs = torch .outer (pos , freqs ) # type: ignore # [S, D/2]
1152+ ) # Shape: [D/2]
1153+
1154+ # Replace torch.outer with broadcasted multiplication
1155+ # Old: freqs = torch.outer(pos, freqs) # Shape: [S, D/2]
1156+ # New: pos is [B, S], freqs is [D/2]. Unsqueeze pos to [B, S, 1] for broadcasting.
1157+ freqs = pos .unsqueeze (- 1 ) * freqs # Shape: [B, S, D/2]
1158+
11551159 is_npu = freqs .device .type == "npu"
11561160 if is_npu :
11571161 freqs = freqs .float ()
1162+
11581163 if use_real and repeat_interleave_real :
11591164 # flux, hunyuan-dit, cogvideox
1160- freqs_cos = freqs .cos ().repeat_interleave (2 , dim = 1 , output_size = freqs .shape [1 ] * 2 ).float () # [S, D]
1161- freqs_sin = freqs .sin ().repeat_interleave (2 , dim = 1 , output_size = freqs .shape [1 ] * 2 ).float () # [S, D]
1165+ # Use dim=-1 for robust interleaving on the feature dimension
1166+ freqs_cos = freqs .cos ().repeat_interleave (2 , dim = - 1 ) # Shape: [B, S, D]
1167+ freqs_sin = freqs .sin ().repeat_interleave (2 , dim = - 1 ) # Shape: [B, S, D]
11621168 return freqs_cos , freqs_sin
11631169 elif use_real :
11641170 # stable audio, allegro
1165- freqs_cos = torch .cat ([freqs .cos (), freqs .cos ()], dim = - 1 ).float () # [ S, D]
1166- freqs_sin = torch .cat ([freqs .sin (), freqs .sin ()], dim = - 1 ).float () # [ S, D]
1171+ freqs_cos = torch .cat ([freqs .cos (), freqs .cos ()], dim = - 1 ).float () # Shape: [B, S, D]
1172+ freqs_sin = torch .cat ([freqs .sin (), freqs .sin ()], dim = - 1 ).float () # Shape: [B, S, D]
11671173 return freqs_cos , freqs_sin
11681174 else :
11691175 # lumina
1170- freqs_cis = torch .polar (torch .ones_like (freqs ), freqs ) # complex64 # [ S, D/2]
1176+ freqs_cis = torch .polar (torch .ones_like (freqs ), freqs ) # Shape: [B, S, D/2]
11711177 return freqs_cis
11721178
11731179
@@ -1246,26 +1252,35 @@ def __init__(self, theta: int, axes_dim: List[int]):
12461252 self .axes_dim = axes_dim
12471253
def forward(self, ids: torch.Tensor) -> torch.Tensor:
    """Build rotary cos/sin tables for every positional axis in ``ids``.

    The last dimension of ``ids`` holds one position per axis. Both a
    batched ``[B, S, n_axes]`` layout and an unbatched ``[S, n_axes]``
    layout are accepted; the output keeps the same leading dimensions
    as the input.

    Args:
        ids: Position ids; axis ``i`` of the last dimension is embedded
            with ``self.axes_dim[i]`` features.

    Returns:
        A ``(freqs_cos, freqs_sin)`` pair (a tuple, despite the return
        annotation). Each tensor is the per-axis result concatenated
        along the last dimension.
    """
    n_axes = ids.shape[-1]
    cos_out = []
    sin_out = []
    pos = ids.float()
    is_mps = ids.device.type == "mps"
    is_npu = ids.device.type == "npu"
    # MPS and NPU devices get float32 frequencies; everything else float64.
    freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64

    for i in range(n_axes):
        cos, sin = get_1d_rotary_pos_embed(
            self.axes_dim[i],
            # Ellipsis indexing selects axis ``i`` for both [S, n_axes] and
            # [B, S, n_axes] inputs. ``pos[:, :, i]`` raised IndexError on
            # 2-D input, which made the unbatched path unreachable.
            pos[..., i],
            theta=self.theta,
            repeat_interleave_real=True,
            use_real=True,
            freqs_dtype=freqs_dtype,
        )
        cos_out.append(cos)
        sin_out.append(sin)

    freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
    freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)

    # For unbatched input the helper prepends a batch axis of size 1; drop
    # it again so callers get [S, D] back. The rank check keeps a genuine
    # length-1 sequence axis intact if the helper returned 2-D output.
    if ids.ndim == 2 and freqs_cos.ndim == 3:
        freqs_cos = freqs_cos.squeeze(0)
        freqs_sin = freqs_sin.squeeze(0)

    return freqs_cos, freqs_sin
12701285
12711286
0 commit comments