2323from collections import namedtuple
2424
2525import torch
26- from torch import nn
26+ from torch import nn , sigmoid
2727from torch import Tensor
2828import torch .nn .functional as F
2929
@@ -80,15 +80,15 @@ def unpack_one(t, ps, pattern):
8080
8181# Loss functions
8282
83- def smoothlddtloss (
83+ @typecheck
84+ def calc_smooth_lddt_loss (
8485 denoised : Float ['b m 3' ],
8586 ground_truth : Float ['b m 3' ],
8687 is_rna_per_atom : Float ['b m' ],
8788 is_dna_per_atom : Float ['b m' ]
88- ) -> Float ['b' ]:
89- from torch import sigmoid
89+ ) -> Float [' b' ]:
9090
91- m = is_rna_per_atom .shape [- 1 ]
91+ m , device = is_rna_per_atom .shape [- 1 ], denoised . device
9292
9393 dx_denoised = torch .cdist (denoised , denoised )
9494 dx_gt = torch .cdist (ground_truth , ground_truth )
@@ -102,10 +102,11 @@ def smoothlddtloss(
102102 mask = einx .multiply ('b i, b j -> b i j' , is_nuc , is_nuc )
103103 c = (dx_gt < 30 ) * mask + (dx_gt < 15 ) * (1 - mask )
104104
105- num = einx .sum ('b [...]' , c * eps * (1 - torch .eye (m ))) / (m ** 2 - m )
106- den = einx .sum ('b [...]' , c * (1 - torch .eye (m ))) / (m ** 2 - m )
107-
108- return 1 - num / den
105+ eye = torch .eye (m , device = device )
106+ num = einx .sum ('b [...]' , c * eps * (1 - eye )) / (m ** 2 - m )
107+ den = einx .sum ('b [...]' , c * (1 - eye )) / (m ** 2 - m )
108+
109+ return 1. - num / den
109110
110111# linear and outer sum
111112# for single repr -> pairwise pattern throughout this architecture
@@ -1699,6 +1700,8 @@ def forward(
16991700 normalized_atom_pos : Float ['b m 3' ],
17001701 atom_mask : Bool ['b m' ],
17011702 return_denoised_pos = False ,
1703+ additional_residue_feats : Float ['b n rf' ] | None = None ,
1704+ add_smooth_lddt_loss = False ,
17021705 ** network_condition_kwargs
17031706 ) -> Float ['' ] | Tuple [Float ['' ], Float ['b m 3' ]]:
17041707
@@ -1726,6 +1729,22 @@ def forward(
17261729
17271730 loss = losses .mean ()
17281731
1732+ if add_smooth_lddt_loss :
1733+ assert exists (additional_residue_feats )
1734+ w = self .net .atoms_per_window
1735+
1736+ is_dna , is_rna = additional_residue_feats [..., 7 ], additional_residue_feats [..., 8 ]
1737+ atom_is_dna , atom_is_rna = tuple (repeat (t , 'b n -> b (n w)' , w = w ) for t in (is_dna , is_rna ))
1738+
1739+ smooth_lddt_loss = calc_smooth_lddt_loss (
1740+ denoised ,
1741+ normalized_atom_pos ,
1742+ atom_is_dna ,
1743+ atom_is_rna
1744+ ).mean ()
1745+
1746+ loss = loss + smooth_lddt_loss
1747+
17291748 if not return_denoised_pos :
17301749 return loss
17311750
@@ -2354,7 +2373,13 @@ def forward(
23542373 # otherwise, noise and make it learn to denoise
23552374
23562375 if exists (atom_pos ):
2357- diffusion_loss , denoised_atom_pos = self .edm (atom_pos , return_denoised_pos = True , ** diffusion_cond )
2376+ diffusion_loss , denoised_atom_pos = self .edm (
2377+ atom_pos ,
2378+ additional_residue_feats = additional_residue_feats ,
2379+ add_smooth_lddt_loss = True ,
2380+ return_denoised_pos = True ,
2381+ ** diffusion_cond
2382+ )
23582383
23592384 # calculate all logits and losses
23602385
0 commit comments