@@ -89,6 +89,20 @@ def inner(t, *args, **kwargs):
8989
9090# packed atom representation functions
9191
92+ @typecheck
93+ def lens_to_mask (
94+ lens : Int ['b n' ] | Int [' b' ]
95+ ) -> Bool ['b m' ]:
96+
97+ device = lens .device
98+
99+ if lens .ndim == 2 :
100+ lens = reduce (lens , 'b m -> b' , 'sum' )
101+
102+ max_len = lens .amax ()
103+ arange = torch .arange (max_len , device = device )
104+ return einx .less ('m, b -> b m' , arange , lens )
105+
92106@typecheck
93107def mean_pool_with_lens (
94108 feats : Float ['b m d' ],
@@ -115,6 +129,51 @@ def mean_pool_with_lens(
115129 avg = einx .where ('b n, b n d, -> b n d' , mask , avg , 0. )
116130 return avg
117131
132+ @typecheck
133+ def repeat_consecutive_with_lens (
134+ feats : Float ['b n d' ],
135+ lens : Int ['b n' ],
136+ max_length : int | None = None ,
137+ return_mask = False
138+ ) -> Float ['b m d' ] | Tuple [Float ['b m d' ], Bool ['b m' ]]:
139+
140+ device = feats .device
141+
142+ # derive arange from the max length
143+
144+ total_lens = reduce (lens , 'b n -> b' , 'sum' )
145+
146+ if not exists (max_length ):
147+ max_length = total_lens .amax ()
148+
149+ arange = torch .arange (max_length , device = device )
150+
151+ # get packed atom mask from the total lengths
152+
153+ mask = lens_to_mask (total_lens )
154+
155+ lens = F .pad (lens , (1 , 0 ), value = 0 )
156+ cumsum_lens = lens .cumsum (dim = - 1 )
157+ left_index , right_index = cumsum_lens [:, :- 1 ], cumsum_lens [:, 1 :]
158+
159+ # derive the mask for consecutives per feat
160+
161+ left_mask = einx .greater_equal ('m, b n -> b n m' , arange , left_index )
162+ right_mask = einx .less ('m, b n -> b n m' , arange , right_index )
163+
164+ consecutive_mask = left_mask & right_mask
165+
166+ # now broadcast and sum for consecutive features
167+
168+ feats = einx .multiply ('b n d, b n m -> b n m d' , feats , consecutive_mask .float ())
169+ feats = reduce (feats , 'b n m d -> b m d' , 'sum' )
170+
171+ if not return_mask :
172+ return feats
173+
174+ mask = mask [:, :max_length ]
175+ return feats , mask
176+
118177# linear and outer sum
119178# for single repr -> pairwise pattern throughout this architecture
120179
0 commit comments