Skip to content

Commit 0b70ab9

Browse files
committed
calculate fibonacci sphere points only once
1 parent bf1d580 commit 0b70ab9

File tree

2 files changed

+31
-26
lines changed

2 files changed

+31
-26
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5288,6 +5288,7 @@ def __init__(
52885288
contact_mask_threshold: float = 8.0,
52895289
is_fine_tuning: bool = False,
52905290
weight_dict_config: dict = None,
5291+
fibonacci_sphere_n = 200, # more points equal better approximation at cost of compute
52915292
):
52925293
super().__init__()
52935294
self.compute_confidence_score = ComputeConfidenceScore(eps=eps)
@@ -5301,6 +5302,8 @@ def __init__(
53015302
self.register_buffer("dist_breaks", dist_breaks)
53025303
self.register_buffer('lddt_thresholds', torch.tensor([0.5, 1.0, 2.0, 4.0]))
53035304

5305+
# for rsa calculation
5306+
53045307
atom_type_radii = tensor([
53055308
1.65, # 0 - nitrogen
53065309
1.87, # 1 - carbon alpha
@@ -5319,6 +5322,31 @@ def __init__(
53195322

53205323
self.register_buffer('atom_radii', atom_type_radii, persistent = False)
53215324

5325+
# constitute the fibonacci sphere
5326+
5327+
num_surface_dots = fibonacci_sphere_n * 2 + 1
5328+
golden_ratio = 1. + sqrt(5.) / 2
5329+
weight = (4. * pi) / num_surface_dots
5330+
5331+
arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
5332+
5333+
lat = torch.asin((2. * arange) / num_surface_dots)
5334+
lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio
5335+
5336+
# ein:
5337+
# sd - surface dots
5338+
# c - coordinate (3)
5339+
# i, j - source and target atom
5340+
5341+
unit_surface_dots: Float['sd 3'] = torch.stack((
5342+
lon.sin() * lat.cos(),
5343+
lon.cos() * lat.cos(),
5344+
lat.sin()
5345+
), dim = -1)
5346+
5347+
self.register_buffer('unit_surface_dots', unit_surface_dots)
5348+
self.surface_weight = weight
5349+
53225350
@property
53235351
def device(self):
53245352
return self.atom_radii.device
@@ -5651,7 +5679,6 @@ def calc_atom_access_surface_score(
56515679
atom_pos: Float['m 3'],
56525680
atom_type: Int['m'],
56535681
molecule_atom_lens: Int['n'] | None = None,
5654-
fibonacci_sphere_n = 200, # more points equal better approximation at cost of compute
56555682
atom_distance_min_thres = 1e-4
56565683
) -> Float['m'] | Float['n']:
56575684

@@ -5666,28 +5693,6 @@ def calc_atom_access_surface_score(
56665693

56675694
# write custom RSA function here
56685695

5669-
# first constitute the fibonacci sphere
5670-
5671-
num_surface_dots = fibonacci_sphere_n * 2 + 1
5672-
golden_ratio = 1. + sqrt(5.) / 2
5673-
weight = (4. * pi) / num_surface_dots
5674-
5675-
arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
5676-
5677-
lat = torch.asin((2. * arange) / num_surface_dots)
5678-
lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio
5679-
5680-
# ein:
5681-
# sd - surface dots
5682-
# c - coordinate (3)
5683-
# i, j - source and target atom
5684-
5685-
unit_surface_dots: Float['sd 3'] = torch.stack((
5686-
lon.sin() * lat.cos(),
5687-
lon.cos() * lat.cos(),
5688-
lat.sin()
5689-
), dim = -1)
5690-
56915696
# get atom relative positions + distance
56925697
# for determining whether to include pairs of atom in calculation for the `free` adjective
56935698

@@ -5715,7 +5720,7 @@ def calc_atom_access_surface_score(
57155720

57165721
# overall logic
57175722

5718-
surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots)
5723+
surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, self.unit_surface_dots)
57195724

57205725
dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1)
57215726

@@ -5725,7 +5730,7 @@ def calc_atom_access_surface_score(
57255730

57265731
is_free = reduce(target_atom_close_or_not_included, 'i sd j -> i sd', 'all') # basically the most important line, calculating whether an atom is free by some distance measure
57275732

5728-
score = reduce(is_free.float() * weight, 'm sd -> m', 'sum')
5733+
score = reduce(is_free.float() * self.surface_weight, 'm sd -> m', 'sum')
57295734

57305735
per_atom_access_surface_score = score * atom_radii_sq
57315736

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.55"
3+
version = "0.5.56"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)