@@ -1741,6 +1741,8 @@ def __init__(
17411741
17421742 self .atom_pos_to_atom_feat = LinearNoBias (3 , dim_atom )
17431743
1744+ self .missing_atom_feat = nn .Parameter (torch .zeros (dim_atom ))
1745+
17441746 self .single_repr_to_atom_feat_cond = nn .Sequential (
17451747 nn .LayerNorm (dim_single ),
17461748 LinearNoBias (dim_single , dim_atom )
@@ -1839,7 +1841,8 @@ def forward(
18391841 pairwise_trunk : Float ['b n n dpt' ],
18401842 pairwise_rel_pos_feats : Float ['b n n dpr' ],
18411843 molecule_atom_lens : Int ['b n' ],
1842- atom_parent_ids : Int ['b m' ] | None = None
1844+ atom_parent_ids : Int ['b m' ] | None = None ,
1845+ missing_atom_mask : Bool ['b m' ]| None = None
18431846 ):
18441847 w = self .atoms_per_window
18451848 device = noised_atom_pos .device
@@ -1864,7 +1867,16 @@ def forward(
18641867
18651868 # the most surprising part of the paper; no geometric biases!
18661869
1867- atom_feats = self .atom_pos_to_atom_feat (noised_atom_pos ) + atom_feats
1870+ noised_atom_pos_feats = self .atom_pos_to_atom_feat (noised_atom_pos )
1871+
1872+ # for missing atoms, replace the noised atom position features with the learned missing atom embedding
1873+
1874+ if exists (missing_atom_mask ):
1875+ noised_atom_pos_feats = einx .where ('b m, d, b m d -> b m d' , missing_atom_mask , self .missing_atom_feat , noised_atom_pos_feats )
1876+
1877+ # add the noised atom position features to the atom features
1878+
1879+ atom_feats = noised_atom_pos_feats + atom_feats
18681880
18691881 # condition atom feats cond (cl) with single repr
18701882
@@ -2199,6 +2211,7 @@ def forward(
21992211 pairwise_trunk : Float ['b n n dpt' ],
22002212 pairwise_rel_pos_feats : Float ['b n n dpr' ],
22012213 molecule_atom_lens : Int ['b n' ],
2214+ missing_atom_mask : Bool ['b m' ] | None = None ,
22022215 atom_parent_ids : Int ['b m' ] | None = None ,
22032216 return_denoised_pos = False ,
22042217 is_molecule_types : Bool [f'b n { IS_MOLECULE_TYPES } ' ] | None = None ,
@@ -2227,6 +2240,7 @@ def forward(
22272240 network_condition_kwargs = dict (
22282241 atom_feats = atom_feats ,
22292242 atom_mask = atom_mask ,
2243+ missing_atom_mask = missing_atom_mask ,
22302244 atompair_feats = atompair_feats ,
22312245 atom_parent_ids = atom_parent_ids ,
22322246 mask = mask ,
@@ -2282,6 +2296,11 @@ def forward(
22822296
22832297 losses = losses * loss_weights
22842298
2299+ # if there are missing atoms, update the atom mask to not include them in the loss
2300+
2301+ if exists (missing_atom_mask ):
2302+ atom_mask = atom_mask & ~ missing_atom_mask
2303+
22852304 # account for atom mask
22862305
22872306 mse_loss = losses [atom_mask ].mean ()
@@ -3337,6 +3356,7 @@ def forward(
33373356 atompair_ids : Int ['b m m' ] | Int ['b nw {self.w} {self.w*2}' ] | None = None ,
33383357 is_molecule_mod : Bool ['b n num_mods' ] | None = None ,
33393358 atom_mask : Bool ['b m' ] | None = None ,
3359+ missing_atom_mask : Bool ['b m' ] | None = None ,
33403360 atom_parent_ids : Int ['b m' ] | None = None ,
33413361 token_bonds : Bool ['b n n' ] | None = None ,
33423362 msa : Float ['b s n d' ] | None = None ,
@@ -3656,6 +3676,7 @@ def forward(
36563676 (
36573677 atom_pos ,
36583678 atom_mask ,
3679+ missing_atom_mask ,
36593680 atom_feats ,
36603681 atom_parent_ids ,
36613682 atompair_feats ,
@@ -3679,6 +3700,7 @@ def forward(
36793700 for t in (
36803701 atom_pos ,
36813702 atom_mask ,
3703+ missing_atom_mask ,
36823704 atom_feats ,
36833705 atom_parent_ids ,
36843706 atompair_feats ,
@@ -3730,6 +3752,7 @@ def forward(
37303752 atom_feats = atom_feats ,
37313753 atompair_feats = atompair_feats ,
37323754 atom_parent_ids = atom_parent_ids ,
3755+ missing_atom_mask = missing_atom_mask ,
37333756 atom_mask = atom_mask ,
37343757 mask = mask ,
37353758 single_trunk_repr = single ,
0 commit comments