@@ -100,11 +100,15 @@ def forward(
         embedded_timestep = self.linear_2(embedded_timestep)

         if temb is not None:
-            embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
+            embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim]

-        shift, scale = embedded_timestep.chunk(2, dim=1)
+        shift, scale = embedded_timestep.chunk(2, dim=-1)
         hidden_states = self.norm(hidden_states)
-        hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+        if embedded_timestep.ndim == 2:
+            shift, scale = (x.unsqueeze(1) for x in (shift, scale))
+
+        hidden_states = hidden_states * (1 + scale) + shift
         return hidden_states

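The norm layer above now accepts either a per-sample `embedded_timestep` of shape [B, 2*C] or a per-token one of shape [B, THW, 2*C]; chunking on `dim=-1` and unsqueezing only in the 2-D case keeps both broadcasting correctly against [B, THW, C] hidden states. A minimal sketch with made-up shapes and a hypothetical stand-alone `adaln` helper (not the actual diffusers module):

# Sketch only: illustrates the dim=-1 chunk plus conditional unsqueeze; shapes are assumptions.
import torch

def adaln(hidden_states, embedded_timestep):
    shift, scale = embedded_timestep.chunk(2, dim=-1)
    if embedded_timestep.ndim == 2:  # [B, 2C]: broadcast over the token axis
        shift, scale = (x.unsqueeze(1) for x in (shift, scale))
    return hidden_states * (1 + scale) + shift

hidden_states = torch.randn(2, 6, 4)   # [B, THW, C]
per_batch = torch.randn(2, 8)          # one embedding per sample
per_token = torch.randn(2, 6, 8)       # one embedding per token (e.g. per-frame timesteps)
assert adaln(hidden_states, per_batch).shape == (2, 6, 4)
assert adaln(hidden_states, per_token).shape == (2, 6, 4)

The same pattern is applied to the gated variant below, which is why the gate returned there no longer needs an unsqueeze at the call site.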
@@ -135,9 +139,13 @@ def forward(
         if temb is not None:
             embedded_timestep = embedded_timestep + temb

-        shift, scale, gate = embedded_timestep.chunk(3, dim=1)
+        shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
         hidden_states = self.norm(hidden_states)
-        hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+        if embedded_timestep.ndim == 2:
+            shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
+
+        hidden_states = hidden_states * (1 + scale) + shift
         return hidden_states, gate

@@ -255,19 +263,19 @@ def forward(
         # 1. Self Attention
         norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
         attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
-        hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + gate * attn_output

         # 2. Cross Attention
         norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
         attn_output = self.attn2(
             norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
         )
-        hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + gate * attn_output

         # 3. Feed Forward
         norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
         ff_output = self.ff(norm_hidden_states)
-        hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
+        hidden_states = hidden_states + gate * ff_output

         return hidden_states

@@ -513,7 +521,23 @@ def forward(
         hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]

         # 4. Timestep embeddings
-        temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+        if timestep.ndim == 1:
+            temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+        elif timestep.ndim == 5:
+            assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
+                f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
+            )
+            timestep = timestep.flatten()
+            temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+            # We can do this because num_frames == post_patch_num_frames, as p_t is 1
+            temb, embedded_timestep = (
+                x.view(batch_size, post_patch_num_frames, 1, 1, -1)
+                .expand(-1, -1, post_patch_height, post_patch_width, -1)
+                .flatten(1, 3)
+                for x in (temb, embedded_timestep)
+            )  # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
+        else:
+            assert False

         # 5. Transformer blocks
         for block in self.transformer_blocks:
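In the 5-D branch, the [B, 1, T, 1, 1] timestep is flattened to [B*T], embedded once per frame, and the resulting [B*T, C] vectors are broadcast to every spatial token of that frame; this relies on num_frames == post_patch_num_frames, which holds because the temporal patch size is 1. A shape-only sketch of the view/expand/flatten step, with assumed sizes (the real module derives post_patch_height/width from the patch size):

# Sketch only: the broadcast of per-frame embeddings to per-token embeddings; sizes are assumptions.
import torch

batch_size, num_frames, post_patch_height, post_patch_width, channels = 2, 3, 4, 5, 16
post_patch_num_frames = num_frames  # temporal patch size p_t is 1

temb = torch.randn(batch_size * num_frames, channels)            # [BT, C], one embedding per frame
temb = (
    temb.view(batch_size, post_patch_num_frames, 1, 1, -1)       # [B, T, 1, 1, C]
    .expand(-1, -1, post_patch_height, post_patch_width, -1)     # [B, T, H, W, C]
    .flatten(1, 3)                                               # [B, THW, C]
)
assert temb.shape == (
    batch_size,
    post_patch_num_frames * post_patch_height * post_patch_width,
    channels,
)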