@@ -6083,8 +6083,12 @@ def forward(
60836083 is_nucleotide = is_rna | is_dna
60846084 is_polymer = is_protein | is_rna | is_dna
60856085
6086- is_any_nucleotide_pair = to_pairwise_mask (is_nucleotide )
6087- is_any_polymer_pair = to_pairwise_mask (is_polymer )
6086+ is_any_nucleotide_pair = einx .logical_and (
6087+ '... i, ... j -> ... i j' , torch .ones_like (is_nucleotide ), is_nucleotide
6088+ )
6089+ is_any_polymer_pair = einx .logical_and (
6090+ '... i, ... j -> ... i j' , torch .ones_like (is_polymer ), is_polymer
6091+ )
60886092
60896093 inclusion_radius = torch .where (
60906094 is_any_nucleotide_pair ,
@@ -6094,7 +6098,11 @@ def forward(
60946098
60956099 is_token_center_atom = torch .zeros_like (atom_pos [..., 0 ], dtype = torch .bool )
60966100 is_token_center_atom [torch .arange (batch_size ).unsqueeze (1 ), molecule_atom_indices ] = True
6097- is_any_token_center_atom_pair = to_pairwise_mask (is_token_center_atom )
6101+ is_any_token_center_atom_pair = einx .logical_and (
6102+ '... i, ... j -> ... i j' ,
6103+ torch .ones_like (is_token_center_atom ),
6104+ is_token_center_atom ,
6105+ )
60986106
60996107 # compute masks, avoiding self term
61006108
0 commit comments