Skip to content

Commit 3c8e86e

Browse files
committed
account for masking
1 parent 23690bb commit 3c8e86e

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3067,6 +3067,7 @@ def forward(
30673067
true_coords: Float['b m_or_n 3'],
30683068
pred_frames: Float['b n 3 3'],
30693069
true_frames: Float['b n 3 3'],
3070+
mask: Bool['b m_or_n'] | None = None,
30703071
molecule_atom_lens: Int['b n'] | None = None
30713072
) -> Float['b m_or_n m_or_n']:
30723073
"""
@@ -3085,17 +3086,20 @@ def forward(
30853086
pred_frames = batch_repeat_interleave(pred_frames, molecule_atom_lens)
30863087
true_frames = batch_repeat_interleave(true_frames, molecule_atom_lens)
30873088

3089+
if not exists(mask) and exists(molecule_atom_lens):
3090+
mask = batch_repeat_interleave(molecule_atom_lens > 0, molecule_atom_lens)
3091+
30883092
# to pairs
30893093

3090-
num_res = pred_coords.shape[1]
3094+
seq = pred_coords.shape[1]
30913095

30923096
pair2seq = partial(rearrange, pattern='b n m ... -> b (n m) ...')
3093-
seq2pair = partial(rearrange, pattern='b (n m) ... -> b n m ...', n = num_res, m = num_res)
3097+
seq2pair = partial(rearrange, pattern='b (n m) ... -> b n m ...', n = seq, m = seq)
30943098

3095-
pair_pred_coords = pair2seq(repeat(pred_coords, 'b n d -> b n m d', m = num_res))
3096-
pair_true_coords = pair2seq(repeat(true_coords, 'b n d -> b n m d', m = num_res))
3097-
pair_pred_frames = pair2seq(repeat(pred_frames, 'b n d e -> b m n d e', m = num_res))
3098-
pair_true_frames = pair2seq(repeat(true_frames, 'b n d e -> b m n d e', m = num_res))
3099+
pair_pred_coords = pair2seq(repeat(pred_coords, 'b n d -> b n m d', m = seq))
3100+
pair_true_coords = pair2seq(repeat(true_coords, 'b n d -> b n m d', m = seq))
3101+
pair_pred_frames = pair2seq(repeat(pred_frames, 'b n d e -> b m n d e', m = seq))
3102+
pair_true_frames = pair2seq(repeat(true_frames, 'b n d e -> b m n d e', m = seq))
30993103

31003104
# Express predicted coordinates in predicted frames
31013105
pred_coords_transformed = self.express_coordinates_in_frame(pair_pred_coords, pair_pred_frames)
@@ -3107,9 +3111,14 @@ def forward(
31073111
alignment_errors = torch.sqrt(
31083112
torch.sum((pred_coords_transformed - true_coords_transformed) ** 2, dim=-1) + self.eps
31093113
)
3110-
3114+
31113115
alignment_errors = seq2pair(alignment_errors)
31123116

3117+
# Masking
3118+
if exists(mask):
3119+
pair_mask = to_pairwise_mask(mask)
3120+
alignment_errors = einx.where('b i j, b i j, -> b i j', pair_mask, alignment_errors, 0.)
3121+
31133122
return alignment_errors
31143123

31153124
class CentreRandomAugmentation(Module):

0 commit comments

Comments
 (0)