@@ -3067,6 +3067,7 @@ def forward(
30673067 true_coords : Float ['b m_or_n 3' ],
30683068 pred_frames : Float ['b n 3 3' ],
30693069 true_frames : Float ['b n 3 3' ],
3070+ mask : Bool ['b m_or_n' ] | None = None ,
30703071 molecule_atom_lens : Int ['b n' ] | None = None
30713072 ) -> Float ['b m_or_n m_or_n' ]:
30723073 """
@@ -3085,17 +3086,20 @@ def forward(
30853086 pred_frames = batch_repeat_interleave (pred_frames , molecule_atom_lens )
30863087 true_frames = batch_repeat_interleave (true_frames , molecule_atom_lens )
30873088
3089+ if not exists (mask ) and exists (molecule_atom_lens ):
3090+ mask = batch_repeat_interleave (molecule_atom_lens > 0 , molecule_atom_lens )
3091+
30883092 # to pairs
30893093
3090- num_res = pred_coords .shape [1 ]
3094+ seq = pred_coords .shape [1 ]
30913095
30923096 pair2seq = partial (rearrange , pattern = 'b n m ... -> b (n m) ...' )
3093- seq2pair = partial (rearrange , pattern = 'b (n m) ... -> b n m ...' , n = num_res , m = num_res )
3097+ seq2pair = partial (rearrange , pattern = 'b (n m) ... -> b n m ...' , n = seq , m = seq )
30943098
3095- pair_pred_coords = pair2seq (repeat (pred_coords , 'b n d -> b n m d' , m = num_res ))
3096- pair_true_coords = pair2seq (repeat (true_coords , 'b n d -> b n m d' , m = num_res ))
3097- pair_pred_frames = pair2seq (repeat (pred_frames , 'b n d e -> b m n d e' , m = num_res ))
3098- pair_true_frames = pair2seq (repeat (true_frames , 'b n d e -> b m n d e' , m = num_res ))
3099+ pair_pred_coords = pair2seq (repeat (pred_coords , 'b n d -> b n m d' , m = seq ))
3100+ pair_true_coords = pair2seq (repeat (true_coords , 'b n d -> b n m d' , m = seq ))
3101+ pair_pred_frames = pair2seq (repeat (pred_frames , 'b n d e -> b m n d e' , m = seq ))
3102+ pair_true_frames = pair2seq (repeat (true_frames , 'b n d e -> b m n d e' , m = seq ))
30993103
31003104 # Express predicted coordinates in predicted frames
31013105 pred_coords_transformed = self .express_coordinates_in_frame (pair_pred_coords , pair_pred_frames )
@@ -3107,9 +3111,14 @@ def forward(
31073111 alignment_errors = torch .sqrt (
31083112 torch .sum ((pred_coords_transformed - true_coords_transformed ) ** 2 , dim = - 1 ) + self .eps
31093113 )
3110-
3114+
31113115 alignment_errors = seq2pair (alignment_errors )
31123116
3117+ # Masking
3118+ if exists (mask ):
3119+ pair_mask = to_pairwise_mask (mask )
3120+ alignment_errors = einx .where ('b i j, b i j, -> b i j' , pair_mask , alignment_errors , 0. )
3121+
31133122 return alignment_errors
31143123
31153124class CentreRandomAugmentation (Module ):
0 commit comments