@@ -87,35 +87,33 @@ def inner(t, *args, **kwargs):
8787 return fn (t , * args , ** kwargs )
8888 return inner
8989
90- # Loss functions
90+ # packed atom representation functions
9191
9292@typecheck
93- def calc_smooth_lddt_loss (
94- denoised : Float ['b m 3' ],
95- ground_truth : Float ['b m 3' ],
96- is_rna_per_atom : Float ['b m' ],
97- is_dna_per_atom : Float ['b m' ]
98- ) -> Float [' b' ]:
99-
100- m , device = is_rna_per_atom .shape [- 1 ], denoised .device
101-
102- dx_denoised = torch .cdist (denoised , denoised )
103- dx_gt = torch .cdist (ground_truth , ground_truth )
104-
105- ddx = torch .abs (dx_gt - dx_denoised )
106- eps = 0.25 * (
107- sigmoid (0.5 - ddx ) + sigmoid (1 - ddx ) + sigmoid (2 - ddx ) + sigmoid (4 - ddx )
108- )
109-
110- is_nuc = is_rna_per_atom + is_dna_per_atom
111- mask = einx .multiply ('b i, b j -> b i j' , is_nuc , is_nuc )
112- c = (dx_gt < 30 ) * mask + (dx_gt < 15 ) * (1 - mask )
113-
114- eye = torch .eye (m , device = device )
115- num = einx .sum ('b [...]' , c * eps * (1 - eye )) / (m ** 2 - m )
116- den = einx .sum ('b [...]' , c * (1 - eye )) / (m ** 2 - m )
93+ def mean_pool_with_lens (
94+ feats : Float ['b m d' ],
95+ lens : Int ['b n' ]
96+ ) -> Float ['b n d' ]:
97+
98+ seq_len = feats .shape [1 ]
99+
100+ mask = lens > 0
101+ assert (lens .sum (dim = - 1 ) <= seq_len ).all (), 'one of the lengths given exceeds the total sequence length of the features passed in'
102+
103+ cumsum_feats = feats .cumsum (dim = 1 )
104+ cumsum_feats = F .pad (cumsum_feats , (0 , 0 , 1 , 0 ), value = 0. )
105+
106+ cumsum_indices = lens .cumsum (dim = 1 )
107+ cumsum_indices = F .pad (cumsum_indices , (1 , 0 ), value = 0 )
108+
109+ sel_cumsum = einx .get_at ('b [m] d, b n -> b n d' , cumsum_feats , cumsum_indices )
110+
111+ # subtract cumsum at one index from the previous one
112+ summed = sel_cumsum [:, 1 :] - sel_cumsum [:, :- 1 ]
117113
118- return 1. - num / den
114+ avg = einx .divide ('b n d, b n' , summed , lens .clamp (min = 1 ))
115+ avg = einx .where ('b n, b n d, -> b n d' , mask , avg , 0. )
116+ return avg
119117
120118# linear and outer sum
121119# for single repr -> pairwise pattern throughout this architecture
0 commit comments