@@ -78,6 +78,35 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]

+# Loss functions
+
+def smooth_lddt_loss(
+    denoised: Float['b m 3'],
+    ground_truth: Float['b m 3'],
+    is_rna_per_atom: Float['b m'],
+    is_dna_per_atom: Float['b m']
+) -> Float['b']:
+    from torch import sigmoid
+
+    m = is_rna_per_atom.shape[-1]
+
+    # pairwise distance matrices for denoised and ground truth coordinates
+    dx_denoised = torch.cdist(denoised, denoised)
+    dx_gt = torch.cdist(ground_truth, ground_truth)
+
+    # smooth lDDT term: average of sigmoids at the 0.5, 1, 2, 4 distance-difference thresholds
+    ddx = torch.abs(dx_gt - dx_denoised)
+    eps = 0.25 * (
+        sigmoid(0.5 - ddx) + sigmoid(1 - ddx) + sigmoid(2 - ddx) + sigmoid(4 - ddx)
+    )
+
+    # inclusion radius - 30 for nucleotide pairs, 15 otherwise
+    is_nuc = is_rna_per_atom + is_dna_per_atom
+    mask = einx.multiply('b i, b j -> b i j', is_nuc, is_nuc)
+    c = (dx_gt < 30) * mask + (dx_gt < 15) * (1 - mask)
+
+    # average over included, off-diagonal atom pairs
+    eye = torch.eye(m, device = denoised.device)
+    num = einx.sum('b [...]', c * eps * (1 - eye)) / (m ** 2 - m)
+    den = einx.sum('b [...]', c * (1 - eye)) / (m ** 2 - m)
+
+    return 1 - num / den
+
 # linear and outer sum
 # for single repr -> pairwise pattern throughout this architecture

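For context, the block added above is a smooth (sigmoid-relaxed) lDDT term that returns `1 - lDDT` per structure. Below is a minimal usage sketch, not part of the commit: it assumes `torch`, `einx`, and the `Float` shape-annotation helper are already imported at the top of the module (as the surrounding file implies), and uses random coordinates purely to illustrate the expected tensor shapes.

```python
# minimal usage sketch (illustrative only, not part of the commit)
import torch

batch, num_atoms = 2, 16

denoised = torch.randn(batch, num_atoms, 3)
ground_truth = denoised + 0.1 * torch.randn(batch, num_atoms, 3)

# flag the first half of the atoms as RNA, none as DNA, for this toy example
is_rna_per_atom = torch.zeros(batch, num_atoms)
is_rna_per_atom[:, :num_atoms // 2] = 1.
is_dna_per_atom = torch.zeros(batch, num_atoms)

loss = smooth_lddt_loss(denoised, ground_truth, is_rna_per_atom, is_dna_per_atom)
print(loss.shape)  # torch.Size([2]) - one loss value per structure in the batch
```

Since the function returns a per-structure loss of shape `Float['b']`, reducing it with `.mean()` before backpropagation would be the usual choice.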