Skip to content

Commit b57414d

Browse files
committed
prepare pae bin centers 64 bins from 0 - 32 angstroms
1 parent 591cf63 commit b57414d

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3520,8 +3520,7 @@ def forward(
35203520

35213521
intermolecule_dist = torch.cdist(pred_molecule_pos, pred_molecule_pos, p = 2)
35223522

3523-
dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', intermolecule_dist, self.atompair_dist_bins).abs()
3524-
dist_bin_indices = dist_from_dist_bins.argmin(dim = -1)
3523+
dist_bin_indices = distance_to_bins(intermolecule_dist, self.atompair_dist_bins)
35253524
pairwise_repr = pairwise_repr + self.dist_bin_pairwise_embed(dist_bin_indices)
35263525

35273526
# pairformer stack
@@ -3546,8 +3545,7 @@ def forward(
35463545

35473546
interatomic_dist = torch.cdist(pred_atom_pos, pred_atom_pos, p = 2)
35483547

3549-
dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', interatomic_dist, self.atompair_dist_bins).abs()
3550-
dist_bin_indices = dist_from_dist_bins.argmin(dim = -1)
3548+
dist_bin_indices = distance_to_bins(interatomic_dist, self.atompair_dist_bins)
35513549
pairwise_repr = pairwise_repr + self.dist_bin_pairwise_embed(dist_bin_indices)
35523550

35533551
single_repr = single_repr + self.atom_feats_to_single(atom_feats)
@@ -4889,11 +4887,12 @@ def __init__(
48894887
num_atompair_embeds: int | None = None,
48904888
num_molecule_mods: int | None = DEFAULT_NUM_MOLECULE_MODS,
48914889
distance_bins: List[float] = torch.linspace(3, 20, 38).float().tolist(),
4890+
pae_bins: List[float] = torch.linspace(0.5, 32, 64).float().tolist(),
48924891
ignore_index = -1,
48934892
num_dist_bins: int | None = None,
48944893
num_plddt_bins = 50,
48954894
num_pde_bins = 64,
4896-
num_pae_bins = 64,
4895+
num_pae_bins: int | None = None,
48974896
sigma_data = 16,
48984897
num_rollout_steps = 20,
48994898
diffusion_num_augmentations = 4,
@@ -5126,21 +5125,30 @@ def __init__(
51265125
**edm_kwargs
51275126
)
51285127

5128+
self.num_rollout_steps = num_rollout_steps
5129+
51295130
# logit heads
51305131

51315132
distance_bins_tensor = Tensor(distance_bins)
51325133

51335134
self.register_buffer('distance_bins', distance_bins_tensor)
51345135
num_dist_bins = default(num_dist_bins, len(distance_bins_tensor))
51355136

5137+
51365138
assert len(distance_bins_tensor) == num_dist_bins, '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
51375139

51385140
self.distogram_head = DistogramHead(
51395141
dim_pairwise = dim_pairwise,
51405142
num_dist_bins = num_dist_bins
51415143
)
51425144

5143-
self.num_rollout_steps = num_rollout_steps
5145+
# pae bins
5146+
5147+
pae_bins_tensor = Tensor(pae_bins)
5148+
self.register_buffer('pae_bins', pae_bins_tensor)
5149+
num_pae_bins = len(pae_bins)
5150+
5151+
# confidence head
51445152

51455153
self.confidence_head = ConfidenceHead(
51465154
dim_single_inputs = dim_single_inputs,

0 commit comments

Comments
 (0)