Skip to content

Commit 097cbff

Browse files
committed
move RSA calculation logic into own method
1 parent 86ee630 commit 097cbff

File tree

3 files changed

+79
-64
lines changed

3 files changed

+79
-64
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 78 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5707,6 +5707,77 @@ def _compute_unresolved_rasa(
57075707

57085708
return unresolved_rasa.mean()
57095709

5710+
@typecheck
def calc_atom_access_surface_score(
    self,
    atom_pos: Float['m 3'],
    atom_type: Int[' m'],
    fibonacci_sphere_n = 200,  # they use 200 in mkdssp, but can be tailored for efficiency
    atom_distance_min_thres = 1e-4
) -> Float['m']:
    """Compute a per-atom accessible surface score by probing a Fibonacci sphere.

    For every atom, `2 * fibonacci_sphere_n + 1` dots are placed on a sphere whose
    radius is the atom radius plus the water radius (the last entry of
    `self.atom_radii`). A dot counts as free when it lies outside the
    (water-expanded) sphere of every neighboring atom that is considered. The
    score of an atom is the solid-angle weighted count of its free dots,
    multiplied by its squared expanded radius.

    Args:
        atom_pos: Cartesian coordinates of the `m` atoms.
        atom_type: Per-atom indices into `self.atom_radii`.
        fibonacci_sphere_n: Half-size of the dot set; `2 * n + 1` dots are used.
        atom_distance_min_thres: Lower bound on the *squared* pairwise distance
            for a pair to be considered. This excludes an atom from its own
            neighbor set (and any exactly overlapping atoms).

    Returns:
        A tensor of `m` accessible surface scores, one per atom.
    """

    device = atom_pos.device

    # atom radii is always summed with water radii
    # (out-of-place add, so the indexed lookup result is never mutated)

    water_radii = self.atom_radii[-1]
    atom_radii: Float[' m'] = self.atom_radii[atom_type] + water_radii

    # always use square of radii or distance for comparison - save on sqrt

    atom_radii_sq = atom_radii.pow(2)

    # first constitute the fibonacci sphere

    num_surface_dots = fibonacci_sphere_n * 2 + 1

    # golden ratio is (1 + sqrt(5)) / 2 - the parentheses matter,
    # `1 + sqrt(5) / 2` would evaluate to ~2.118 and distort the dot spacing

    golden_ratio = (1. + sqrt(5.)) / 2

    # every dot covers an equal share of the unit sphere's 4 * pi solid angle

    weight = (4. * pi) / num_surface_dots

    # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
    # created on the same device as the positions so later broadcasts do not mix devices

    arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = device)

    lat = torch.asin((2. * arange) / num_surface_dots)
    lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio

    # ein:
    # sd - surface dots
    # c - coordinate (3)
    # i, j - source and target atom

    unit_surface_dots: Float['sd 3'] = torch.stack((
        lon.sin() * lat.cos(),
        lon.cos() * lat.cos(),
        lat.sin()
    ), dim = -1)

    # first get atom relative positions + distance
    # for determining whether to include pairs of atom in calculation for the `free` adjective

    atom_rel_pos = einx.subtract('j c, i c -> i j c', atom_pos, atom_pos)
    atom_rel_dist_sq = atom_rel_pos.pow(2).sum(dim = -1)

    # two expanded spheres can only intersect if their centers are closer than the sum of radii

    max_distance_include = einx.add('i, j -> i j', atom_radii, atom_radii).pow(2)

    include_in_free_calc = (
        (atom_rel_dist_sq < max_distance_include) &
        (atom_rel_dist_sq > atom_distance_min_thres)
    )

    # overall logic

    # dots on the expanded sphere of each source atom, relative to that atom's center

    surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots)

    # squared distance from every dot of source atom i to the center of target atom j

    dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1)

    # True where the dot lies strictly outside the expanded sphere of target atom j

    dot_outside_target_atom = einx.less('j, i sd j -> i sd j', atom_radii_sq, dist_from_surface_dots_sq)

    # a target atom cannot bury a dot if the dot is outside of it, or if the pair is not considered at all

    dot_outside_or_not_included = einx.logical_or('i sd j, i j -> i sd j', dot_outside_target_atom, ~include_in_free_calc)

    # basically the most important line, a dot is free only if no considered target atom covers it

    is_free = reduce(dot_outside_or_not_included, 'i sd j -> i sd', 'all')

    score = reduce(is_free.float() * weight, 'm sd -> m', 'sum')

    # scale the free solid angle by the squared expanded radius

    return score * atom_radii_sq
5780+
57105781
@typecheck
57115782
def _inhouse_compute_unresolved_rasa(
57125783
self,
@@ -5717,8 +5788,7 @@ def _inhouse_compute_unresolved_rasa(
57175788
molecule_atom_lens: Int[" n"],
57185789
atom_pos: Float["m 3"],
57195790
atom_mask: Bool[" m"],
5720-
fibonacci_sphere_n = 200, # they use 200 in mkdssp, but can be tailored for efficiency
5721-
atom_distance_min_thres = 1e-4
5791+
**rsa_calc_kwargs
57225792
) -> Float[""]:
57235793
"""Compute the unresolved relative solvent accessible surface area (RASA) for proteins.
57245794
using inhouse rebuilt RSA calculation
@@ -5778,73 +5848,19 @@ def _inhouse_compute_unresolved_rasa(
57785848
structure_atom_pos: Float[' m 3'] = tensor(structure_atom_pos)
57795849
structure_atom_type_for_radii: Int[' m'] = tensor(structure_atom_type_for_radii)
57805850

5781-
atom_radii: Float[' m'] = self.atom_radii[structure_atom_type_for_radii]
5782-
5783-
water_radii = self.atom_radii[-1]
5784-
5785-
# atom radii is always summed with water radii
5786-
5787-
atom_radii += water_radii
5788-
atom_radii_sq = atom_radii.pow(2) # always use square of radii or distance for comparison - save on sqrt
5789-
5790-
# write custom RSA function here
5791-
5792-
# first constitute the fibonacci sphere
5851+
# per atom rsa calculation
57935852

5794-
num_surface_dots = fibonacci_sphere_n * 2 + 1
5795-
golden_ratio = 1. + sqrt(5.) / 2
5796-
weight = (4. * pi) / num_surface_dots
5797-
5798-
arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
5799-
5800-
lat = torch.asin((2. * arange) / num_surface_dots)
5801-
lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio
5802-
5803-
# ein:
5804-
# sd - surface dots
5805-
# c - coordinate (3)
5806-
# i, j - source and target atom
5807-
5808-
unit_surface_dots: Float['sd 3'] = torch.stack((
5809-
lon.sin() * lat.cos(),
5810-
lon.cos() * lat.cos(),
5811-
lat.sin()
5812-
), dim = -1)
5813-
5814-
# first get atom relative positions + distance
5815-
# for determining whether to include pairs of atom in calculation for the `free` adjective
5816-
5817-
atom_rel_pos = einx.subtract('j c, i c -> i j c', structure_atom_pos, structure_atom_pos)
5818-
atom_rel_dist_sq = atom_rel_pos.pow(2).sum(dim = -1)
5819-
5820-
max_distance_include = einx.add('i, j -> i j', atom_radii, atom_radii).pow(2)
5821-
5822-
include_in_free_calc = (
5823-
(atom_rel_dist_sq < max_distance_include) &
5824-
(atom_rel_dist_sq > atom_distance_min_thres)
5853+
per_atom_access_surface_score = self.calc_atom_access_surface_score(
5854+
structure_atom_pos,
5855+
structure_atom_type_for_radii,
5856+
**rsa_calc_kwargs
58255857
)
58265858

5827-
# overall logic
5828-
5829-
surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots)
5830-
5831-
dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1)
5832-
5833-
target_atom_close_to_surface_dots = einx.less('j, i sd j -> i sd j', atom_radii_sq, dist_from_surface_dots_sq)
5834-
5835-
target_atom_close_or_not_included = einx.logical_or('i sd j, i j -> i sd j', target_atom_close_to_surface_dots, ~include_in_free_calc)
5836-
5837-
is_free = reduce(target_atom_close_or_not_included, 'i sd j -> i sd', 'all') # basically the most important line, calculating whether an atom is free by some distance measure
5838-
5839-
score = reduce(is_free.float() * weight, 'm sd -> m', 'sum')
5840-
5841-
per_atom_accessible_surface_score = score * atom_radii_sq
5842-
58435859
# sum up all surface scores for atoms per residue
58445860
# the final score seems to be the average of the rsa across all residues (selected by `chain_unresolved_residue_mask`)
58455861

58465862
rasa, mask = sum_pool_with_lens(
5847-
rearrange(per_atom_accessible_surface_score, '... -> 1 ... 1'),
5863+
rearrange(per_atom_access_surface_score, '... -> 1 ... 1'),
58485864
rearrange(chain_molecule_atom_lens, '... -> 1 ...')
58495865
)
58505866

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.47"
3+
version = "0.5.48"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
ComputeRankingScore,
3838
ConfidenceHeadLogits,
3939
ComputeModelSelectionScore,
40-
ComputeModelSelectionScore,
4140
collate_inputs_to_batched_atom_input,
4241
alphafold3_inputs_to_batched_atom_input,
4342
)

0 commit comments

Comments
 (0)