@@ -91,17 +91,15 @@ def inner(t, *args, **kwargs):
9191
9292@typecheck
9393def lens_to_mask (
94- lens : Int ['b n' ] | Int [' b' ]
95- ) -> Bool ['b m' ]:
94+ lens : Int ['b ...' ],
95+ max_len : int | None = None
96+ ) -> Bool ['... m' ]:
9697
9798 device = lens .device
98-
99- if lens .ndim == 2 :
100- lens = reduce (lens , 'b m -> b' , 'sum' )
101-
102- max_len = lens .amax ()
99+ if not exists (max_len ):
100+ max_len = lens .amax ()
103101 arange = torch .arange (max_len , device = device )
104- return einx .less ('m, b -> b m' , arange , lens )
102+ return einx .less ('m, ... -> ... m' , arange , lens )
105103
106104@typecheck
107105def mean_pool_with_lens (
@@ -133,50 +131,70 @@ def mean_pool_with_lens(
133131def repeat_consecutive_with_lens (
134132 feats : Float ['b n ...' ] | Bool ['b n' ],
135133 lens : Int ['b n' ],
136- max_length : int | None = None ,
137- return_mask = False
138- ) -> Float ['b m d' ] | Bool ['b m' ] | Tuple [Float ['b m d' ] | Bool ['b m' ], Bool ['b m' ]]:
134+ ) -> Float ['b m ...' ] | Bool ['b m' ]:
139135
140136 is_bool = feats .dtype == torch .bool
137+ feats = feats .float ()
138+
141139 device = feats .device
142140
143- # derive arange from the max length
141+ batch , seq , * dims = feats .shape
142+
143+ # get mask from lens
144+
145+ mask = lens_to_mask (lens )
146+
147+ # derive arange
144148
145- total_lens = reduce (lens , 'b n -> b' , 'sum' )
149+ window_size = mask .shape [- 1 ]
150+ arange = torch .arange (window_size , device = device )
146151
147- if not exists (max_length ):
148- max_length = total_lens .amax ()
152+ cumsum_len = lens .cumsum (dim = - 1 )
153+ offsets = F .pad (cumsum_len , (1 , - 1 ), value = 0 )
154+ indices = einx .add ('w, b n -> b n w' , arange , offsets )
149155
150- arange = torch . arange ( max_length , device = device )
156+ # create output tensor + a sink position on the very right (index max_len )
151157
152- # get packed atom mask from the total lengths
158+ total_lens = lens .sum (dim = - 1 )
159+ max_len = total_lens .amax ()
153160
154- mask = lens_to_mask ( total_lens )
161+ output = torch . zeros (( batch , max_len + 1 , * dims ), device = device )
155162
156- lens = F .pad (lens , (1 , 0 ), value = 0 )
157- cumsum_lens = lens .cumsum (dim = - 1 )
158- left_index , right_index = cumsum_lens [:, :- 1 ], cumsum_lens [:, 1 :]
163+ indices .masked_fill_ (~ mask , max_len ) # scatter to sink position for padding
164+ indices = rearrange (indices , 'b n w -> b (n w)' )
159165
160- # derive the mask for consecutives per feat
166+ feats = repeat ( feats , 'b n ... -> b (n w) ...' , w = window_size )
161167
162- left_mask = einx .greater_equal ('m, b n -> b n m' , arange , left_index )
163- right_mask = einx .less ('m, b n -> b n m' , arange , right_index )
168+ # scatter
164169
165- consecutive_mask = left_mask & right_mask
170+ output = einx . set_at ( 'b [m] ..., b nw, b nw ... -> b [m] ...' , output , indices , feats )
166171
167- # now broadcast and sum for consecutive features
172+ # remove sink
168173
169- feats = einx .multiply ('b n ..., b n m -> b n m ...' , feats , consecutive_mask .float ())
170- feats = reduce (feats , 'b n m ... -> b m ...' , 'sum' )
174+ output = output [:, :- 1 ]
171175
172176 if is_bool :
173- feats = feats .bool ()
177+ output = output .bool ()
174178
175- if not return_mask :
176- return feats
179+ return output
177180
178- mask = mask [:, :max_length ]
179- return feats , mask
181+ def repeat_pairwise_consecutive_with_lens (
182+ feats : Float ['b n n dp' ],
183+ lens : Int ['b n' ]
184+ ) -> Float ['b m m dp' ]:
185+
186+ repeated_lens = repeat (lens , 'b ... -> (b r) ...' , r = feats .shape [1 ])
187+ feats , ps = pack_one (feats , '* n dp' )
188+ feats = repeat_consecutive_with_lens (feats , repeated_lens )
189+ feats = unpack_one (feats , ps , '* n dp' )
190+
191+ feats = rearrange (feats , 'b i j dp -> b j i dp' )
192+ repeated_lens = repeat (lens , 'b ... -> (b r) ...' , r = feats .shape [1 ])
193+ feats , ps = pack_one (feats , '* n dp' )
194+ feats = repeat_consecutive_with_lens (feats , repeated_lens )
195+ feats = unpack_one (feats , ps , '* n dp' )
196+ feats = rearrange (feats , 'b j i dp -> b i j dp' )
197+ return feats
180198
181199# linear and outer sum
182200# for single repr -> pairwise pattern throughout this architecture
@@ -1607,19 +1625,7 @@ def forward(
16071625 if is_unpacked_repr :
16081626 pairwise_repr_cond = repeat (pairwise_repr_cond , 'b i j dp -> b (i w1) (j w2) dp' , w1 = w , w2 = w )
16091627 else :
1610- # todo - fix by doing a specialized fn for this
1611-
1612- repeated_residue_atom_lens = repeat (residue_atom_lens , 'b ... -> (b r) ...' , r = pairwise_repr_cond .shape [1 ])
1613- pairwise_repr_cond , ps = pack_one (pairwise_repr_cond , '* n dp' )
1614- pairwise_repr_cond = repeat_consecutive_with_lens (pairwise_repr_cond , repeated_residue_atom_lens )
1615- pairwise_repr_cond = unpack_one (pairwise_repr_cond , ps , '* n dp' )
1616-
1617- pairwise_repr_cond = rearrange (pairwise_repr_cond , 'b i j dp -> b j i dp' )
1618- repeated_residue_atom_lens = repeat (residue_atom_lens , 'b ... -> (b r) ...' , r = pairwise_repr_cond .shape [1 ])
1619- pairwise_repr_cond , ps = pack_one (pairwise_repr_cond , '* n dp' )
1620- pairwise_repr_cond = repeat_consecutive_with_lens (pairwise_repr_cond , repeated_residue_atom_lens )
1621- pairwise_repr_cond = unpack_one (pairwise_repr_cond , ps , '* n dp' )
1622- pairwise_repr_cond = rearrange (pairwise_repr_cond , 'b j i dp -> b i j dp' )
1628+ pairwise_repr_cond = repeat_pairwise_consecutive_with_lens (pairwise_repr_cond , residue_atom_lens )
16231629
16241630 atompair_feats = pairwise_repr_cond + atompair_feats
16251631
@@ -2834,7 +2840,8 @@ def forward(
28342840
28352841 # handle atom mask
28362842
2837- atom_mask = lens_to_mask (residue_atom_lens )
2843+ total_atoms = residue_atom_lens .sum (dim = - 1 )
2844+ atom_mask = lens_to_mask (total_atoms )
28382845 atom_mask = atom_mask [:, :atom_seq_len ]
28392846
28402847 # handle offsets for residue atom indices
@@ -2896,7 +2903,8 @@ def forward(
28962903 # pairwise mask
28972904
28982905 if self .packed_atom_repr :
2899- mask = lens_to_mask (residue_atom_lens )
2906+ total_atoms = residue_atom_lens .sum (dim = - 1 )
2907+ mask = lens_to_mask (total_atoms )
29002908 mask = mask [:, :seq_len ]
29012909 else :
29022910 mask = reduce (atom_mask , 'b (n w) -> b n' , w = w , reduction = 'any' )
0 commit comments