@@ -383,6 +383,28 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
383383 # compute the shortcut part
384384 shortcut = rearrange (x , "b (r2 r3 c) f h w -> b c f (h r2) (w r3)" , r2 = 2 , r3 = 2 )
385385 shortcut = shortcut .repeat_interleave (repeats = self .repeats // 2 , dim = 1 )
386+ elif feat_cache is None and x .shape [2 ] > 1 :
387+ # Multi-frame input without cache: the first frame is upsampled spatially only, while the remaining frames are upsampled both spatially and temporally
388+ # Separate first frame and remaining frames
389+ h_first = h [:, :, :1 , :, :] # first frame
390+ h_rest = h [:, :, 1 :, :, :] # remaining frames
391+ x_first = x [:, :, :1 , :, :]
392+ x_rest = x [:, :, 1 :, :, :]
393+
394+ # First frame: only spatial upsample
395+ h_first = rearrange (h_first , "b (r2 r3 c) f h w -> b c f (h r2) (w r3)" , r2 = 2 , r3 = 2 )
396+ h_first = h_first [:, : h_first .shape [1 ] // 2 ]
397+ shortcut_first = rearrange (x_first , "b (r2 r3 c) f h w -> b c f (h r2) (w r3)" , r2 = 2 , r3 = 2 )
398+ shortcut_first = shortcut_first .repeat_interleave (repeats = self .repeats // 2 , dim = 1 )
399+ out_first = h_first + shortcut_first
400+
401+ # Remaining frames: spatio-temporal upsample
402+ h_rest = rearrange (h_rest , "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)" , r1 = r1 , r2 = 2 , r3 = 2 )
403+ shortcut_rest = rearrange (x_rest , "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)" , r1 = r1 , r2 = 2 , r3 = 2 )
404+ shortcut_rest = shortcut_rest .repeat_interleave (repeats = self .repeats , dim = 1 )
405+ out_rest = h_rest + shortcut_rest
406+
407+ return torch .cat ([out_first , out_rest ], dim = 2 )
386408 else :
387409 h = rearrange (h , "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)" , r1 = r1 , r2 = 2 , r3 = 2 )
388410 # compute the shortcut part
@@ -870,8 +892,9 @@ def tile_parallel_spatial_tiled_decode(self, z: torch.Tensor):
870892 decoded_metas .append (torch .tensor ([ri , rj , pad_w , pad_h ], device = z .device , dtype = torch .int64 ))
871893
872894 while len (decoded_tiles ) < tiles_per_rank :
895+ T_out = decoded_tiles [0 ].shape [2 ] if len (decoded_tiles ) > 0 else (T - 1 )* self .ffactor_temporal + 1
873896 zero_tile = torch .zeros (
874- [1 , 3 , ( T - 1 ) * self . ffactor_temporal + 1 , self .tile_sample_min_size , self .tile_sample_min_size ],
897+ [1 , 3 , T_out , self .tile_sample_min_size , self .tile_sample_min_size ],
875898 device = dec .device ,
876899 dtype = dec .dtype
877900 )
@@ -891,6 +914,7 @@ def tile_parallel_spatial_tiled_decode(self, z: torch.Tensor):
891914
892915 dist .all_gather (tiles_gather_list , decoded_tiles , group = get_parallel_state ().sp_group )
893916 dist .all_gather (metas_gather_list , decoded_metas , group = get_parallel_state ().sp_group )
917+ dist .barrier ()
894918
895919 if rank != 0 :
896920 return torch .empty (0 , device = z .device )
0 commit comments