@@ -1510,7 +1510,7 @@ def __init__(self,
1510
1510
operations = None ,
1511
1511
):
1512
1512
1513
- super ().__init__ (model_type = 't2v ' , patch_size = patch_size , text_len = text_len , in_dim = in_dim , dim = dim , ffn_dim = ffn_dim , freq_dim = freq_dim , text_dim = text_dim , out_dim = out_dim , num_heads = num_heads , num_layers = num_layers , window_size = window_size , qk_norm = qk_norm , cross_attn_norm = cross_attn_norm , eps = eps , flf_pos_embed_token_number = flf_pos_embed_token_number , wan_attn_block_class = WanAttentionBlockAudio , image_model = image_model , device = device , dtype = dtype , operations = operations )
1513
+ super ().__init__ (model_type = 'i2v ' , patch_size = patch_size , text_len = text_len , in_dim = 36 , dim = dim , ffn_dim = ffn_dim , freq_dim = freq_dim , text_dim = text_dim , out_dim = out_dim , num_heads = num_heads , num_layers = num_layers , window_size = window_size , qk_norm = qk_norm , cross_attn_norm = cross_attn_norm , eps = eps , flf_pos_embed_token_number = flf_pos_embed_token_number , wan_attn_block_class = WanAttentionBlockAudio , image_model = image_model , device = device , dtype = dtype , operations = operations )
1514
1514
1515
1515
self .audio_proj = AudioProjModel (seq_len = 8 , blocks = 5 , channels = 1280 , intermediate_dim = 512 , output_dim = 1536 , context_tokens = audio_token_num , dtype = dtype , device = device , operations = operations )
1516
1516
@@ -1539,6 +1539,12 @@ def forward_orig(
1539
1539
e0 = self .time_projection (e ).unflatten (2 , (6 , self .dim ))
1540
1540
1541
1541
if reference_latent is not None :
1542
+ if reference_latent .shape [1 ] < 36 :
1543
+ padding_needed = 36 - reference_latent .shape [1 ]
1544
+ padding = torch .zeros (reference_latent .shape [0 ], padding_needed , * reference_latent .shape [2 :],
1545
+ device = reference_latent .device , dtype = reference_latent .dtype )
1546
+ reference_latent = torch .cat ([padding , reference_latent ], dim = 1 ) # pad at beginning like c_concat
1547
+
1542
1548
ref = self .patch_embedding (reference_latent .float ()).to (x .dtype )
1543
1549
ref = ref .flatten (2 ).transpose (1 , 2 )
1544
1550
freqs_ref = self .rope_encode (reference_latent .shape [- 3 ], reference_latent .shape [- 2 ], reference_latent .shape [- 1 ], t_start = time , device = x .device , dtype = x .dtype )
@@ -1548,7 +1554,7 @@ def forward_orig(
1548
1554
1549
1555
# context
1550
1556
context = self .text_embedding (context )
1551
- context_img_len = None
1557
+ context_img_len = 0
1552
1558
1553
1559
if audio_embed is not None :
1554
1560
if reference_latent is not None :
0 commit comments