@@ -151,10 +151,7 @@ def repeat_consecutive_with_lens(
     lens: Int['b n'],
 ) -> Float['b m ...'] | Bool['b m']:
 
-    is_bool = feats.dtype == torch.bool
-    feats = feats.float()
-
-    device = feats.device
+    device, dtype = feats.device, feats.dtype
 
     batch, seq, *dims = feats.shape
 
@@ -174,25 +171,38 @@ def repeat_consecutive_with_lens(
     # create output tensor + a sink position on the very right (index max_len)
 
     total_lens = lens.sum(dim = -1)
+    output_mask = lens_to_mask(total_lens)
+
     max_len = total_lens.amax()
 
-    output = torch.zeros((batch, max_len + 1, *dims), device = device)
+    output_indices = torch.zeros((batch, max_len + 1), device = device, dtype = torch.long)
 
     indices.masked_fill_(~mask, max_len) # scatter to sink position for padding
     indices = rearrange(indices, 'b n w -> b (n w)')
 
-    feats = repeat(feats, 'b n ... -> b (n w) ...', w = window_size)
-
     # scatter
 
-    output = einx.set_at('b [m] ..., b nw, b nw ... -> b [m] ...', output, indices, feats)
+    seq_arange = torch.arange(seq, device = device)
+    seq_arange = repeat(seq_arange, 'n -> (n w)', w = window_size)
+
+    output_indices = einx.set_at('b [m], b nw, nw -> b [m]', output_indices, indices, seq_arange)
 
     # remove sink
 
-    output = output[:, :-1]
+    output_indices = output_indices[:, :-1]
+
+    # gather
+
+    output = einx.get_at('b [n] ..., b m -> b m ...', feats, output_indices)
+
+    # final mask
+
+    mask_value = False if dtype == torch.bool else 0
 
-    if is_bool:
-        output = output.bool()
+    output = einx.where(
+        'b n, b n ..., -> b n ...',
+        output_mask, output, mask_value
+    )
 
     return output
 
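For readers reviewing the change, here is a minimal reference sketch of the behaviour `repeat_consecutive_with_lens` is expected to produce: the feature at each position `i` is repeated `lens[..., i]` times back to back, and rows are padded with zeros (or `False` for boolean inputs) up to the longest total length in the batch. It uses plain `torch.repeat_interleave` instead of the windowed einx scatter/gather above, and the helper name `repeat_consecutive_reference` is illustrative only, not part of the repository.

```python
import torch

def repeat_consecutive_reference(feats, lens):
    # feats: (batch, seq, ...), lens: (batch, seq) of non-negative integers
    batch = feats.shape[0]
    total_lens = lens.sum(dim = -1)
    max_len = int(total_lens.amax().item())

    # zeros (False for bool feats) so padded positions come out masked
    output = torch.zeros((batch, max_len, *feats.shape[2:]), dtype = feats.dtype, device = feats.device)

    for b in range(batch):
        # repeat each position lens[b, i] times, consecutively
        repeated = torch.repeat_interleave(feats[b], lens[b], dim = 0)
        output[b, :repeated.shape[0]] = repeated

    return output

feats = torch.tensor([[1., 2., 3.]])   # (batch = 1, seq = 3)
lens  = torch.tensor([[2, 0, 3]])
print(repeat_consecutive_reference(feats, lens))   # tensor([[1., 1., 3., 3., 3.]])
```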