@@ -48,20 +48,6 @@ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], thet
         self.theta = theta
 
     def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
-        # This is from the original code. We don't call _forward for each batch index because we know that
-        # every item in the batch has the same frame indices. However, the frame indices may not always be
-        # the same for every item in a batch (such as in training). We cannot use the original implementation
-        # because our `apply_rotary_emb` function broadcasts across the batch dim, so we would first need to
-        # implement another attention processor or modify the existing one with a different `apply_rotary_emb`.
-        # frame_indices = frame_indices.unbind(0)
-        # freqs = [self._forward(f, height, width, device) for f in frame_indices]
-        # freqs_cos, freqs_sin = zip(*freqs)
-        # freqs_cos = torch.stack(freqs_cos, dim=0)  # [B, W * H * T, D / 2]
-        # freqs_sin = torch.stack(freqs_sin, dim=0)  # [B, W * H * T, D / 2]
-        # return freqs_cos, freqs_sin
-        return self._forward(frame_indices, height, width, device)
-
-    def _forward(self, frame_indices, height, width, device):
         height = height // self.patch_size
         width = width // self.patch_size
         grid = torch.meshgrid(
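The removed comment above explains that per-item frame indices would require a batch-aware rotary application, since the existing `apply_rotary_emb` broadcasts a single cos/sin table across the batch dimension. Below is a minimal sketch of what such a batch-aware variant could look like, assuming the `[B, W * H * T, D / 2]` tables produced by the commented-out stacking and an interleaved-pair RoPE layout; the name `apply_rotary_emb_batched` and the pairing convention are assumptions, not the library's actual implementation.

```python
import torch

def apply_rotary_emb_batched(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    # Hypothetical batch-aware rotary application (sketch, not the library's apply_rotary_emb).
    # x: [B, heads, seq, dim]; freqs_cos / freqs_sin: [B, seq, dim // 2] (one table per batch item).
    # Expand each half-dim table to the full head dim and add a heads axis for broadcasting.
    cos = freqs_cos.repeat_interleave(2, dim=-1).unsqueeze(1)  # [B, 1, seq, dim]
    sin = freqs_sin.repeat_interleave(2, dim=-1).unsqueeze(1)  # [B, 1, seq, dim]
    # Interleaved pair rotation: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
    x1, x2 = x[..., 0::2], x[..., 1::2]
    x_rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return x * cos + x_rotated * sin
```

Fed with the per-item tables from the commented-out code (a `torch.stack` of each item's `_forward` output along dim 0), this keeps one rotation table per batch element instead of broadcasting a shared one across the batch.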