@@ -223,17 +223,23 @@ def __init__(
223223
224224    def  forward (self , hidden_states : torch .Tensor ) ->  torch .Tensor :
225225        hidden_states  =  torch .cat ([hidden_states [:, :, : self .stride [0 ] -  1 ], hidden_states ], dim = 2 )
226-         hidden_states  =  (
226+ 
227+         residual  =  (
227228            hidden_states .unflatten (4 , (- 1 , self .stride [2 ]))
228229            .unflatten (3 , (- 1 , self .stride [1 ]))
229230            .unflatten (2 , (- 1 , self .stride [0 ]))
230231        )
231-         hidden_states  =  hidden_states .permute (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 ).flatten (1 , 4 )
232+         residual  =  residual .permute (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 ).flatten (1 , 4 )
233+         residual  =  residual .unflatten (1 , (- 1 , self .group_size ))
234+         residual  =  residual .mean (dim = 2 )
232235
233-         residual  =  hidden_states 
234-         hidden_states  =  hidden_states .unflatten (1 , (- 1 , self .group_size ))
235-         hidden_states  =  hidden_states .mean (dim = 2 )
236236        hidden_states  =  self .conv (hidden_states )
237+         hidden_states  =  (
238+             hidden_states .unflatten (4 , (- 1 , self .stride [2 ]))
239+             .unflatten (3 , (- 1 , self .stride [1 ]))
240+             .unflatten (2 , (- 1 , self .stride [0 ]))
241+         )
242+         hidden_states  =  hidden_states .permute (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 ).flatten (1 , 4 )
237243        hidden_states  =  hidden_states  +  residual 
238244
239245        return  hidden_states 
0 commit comments