@@ -223,17 +223,23 @@ def __init__(
223223
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Downsample a 5-D tensor with a conv branch plus a parameter-free shortcut.

    Steps, in order:

    1. The first ``self.stride[0] - 1`` slices along dim 2 are prepended to
       the input along that same dim (presumably causal padding of a
       time axis -- confirm against the caller).
    2. Shortcut branch: stride-sized blocks of dims 2, 3 and 4 are folded
       into dim 1, then every ``self.group_size`` consecutive channels are
       averaged.
    3. Main branch: ``self.conv`` runs on the padded, *un-rearranged*
       tensor, and its output is folded into channels the same way.
    4. The two branches are summed.

    Args:
        hidden_states: 5-D tensor. After step 1, dims 2/3/4 must be
            divisible by ``self.stride[0]``/``[1]``/``[2]`` respectively,
            otherwise ``unflatten`` raises.

    Returns:
        Sum of the conv branch and the shortcut branch; dims 2, 3 and 4 of
        the padded input are each divided by the matching stride entry.
    """
    stride_0, stride_1, stride_2 = self.stride[0], self.stride[1], self.stride[2]

    def fold_blocks_into_channels(tensor: torch.Tensor) -> torch.Tensor:
        # Split each of dims 4, 3, 2 into (num_blocks, offset_in_block);
        # going from the last dim backwards keeps earlier indices stable.
        blocks = (
            tensor.unflatten(4, (-1, stride_2))
            .unflatten(3, (-1, stride_1))
            .unflatten(2, (-1, stride_0))
        )
        # Move the three offset axes right behind dim 1 and merge them
        # into it, so dim 1 grows by stride_0 * stride_1 * stride_2.
        return blocks.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)

    # Repeat the leading slices of dim 2 in front of the input.
    leading = hidden_states[:, :, : stride_0 - 1]
    hidden_states = torch.cat([leading, hidden_states], dim=2)

    # Shortcut: rearrange, then average channel groups to reach the
    # channel count the conv branch ends up with.
    residual = fold_blocks_into_channels(hidden_states)
    residual = residual.unflatten(1, (-1, self.group_size)).mean(dim=2)

    # Main branch: the convolution sees the original layout; only its
    # output is rearranged.
    hidden_states = fold_blocks_into_channels(self.conv(hidden_states))

    return hidden_states + residual
0 commit comments