Skip to content

Commit bf1d580

Browse files
committed
able to compute rsa on gpu
1 parent 5f97fcb commit bf1d580

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5319,6 +5319,10 @@ def __init__(
53195319

53205320
self.register_buffer('atom_radii', atom_type_radii, persistent = False)
53215321

5322+
@property
5323+
def device(self):
5324+
return self.atom_radii.device
5325+
53225326
@typecheck
53235327
def compute_gpde(
53245328
self,
@@ -5629,10 +5633,10 @@ def calc_atom_access_surface_score_from_structure(
56295633
structure_atom_pos.append(one_atom_pos)
56305634
structure_atom_type_for_radii.append(one_atom_type)
56315635

5632-
structure_atom_pos: Float['m 3'] = tensor(structure_atom_pos)
5633-
structure_atom_type_for_radii: Int['m'] = tensor(structure_atom_type_for_radii)
5636+
structure_atom_pos: Float['m 3'] = tensor(structure_atom_pos, device = self.device)
5637+
structure_atom_type_for_radii: Int['m'] = tensor(structure_atom_type_for_radii, device = self.device)
56345638

5635-
structure_atoms_per_residue: Int['n'] = tensor([len([*residue.get_atoms()]) for residue in structure.get_residues()]).long()
5639+
structure_atoms_per_residue: Int['n'] = tensor([len([*residue.get_atoms()]) for residue in structure.get_residues()], device = self.device).long()
56365640

56375641
return self.calc_atom_access_surface_score(
56385642
atom_pos = structure_atom_pos,
@@ -5668,7 +5672,7 @@ def calc_atom_access_surface_score(
56685672
golden_ratio = 1. + sqrt(5.) / 2
56695673
weight = (4. * pi) / num_surface_dots
56705674

5671-
arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
5675+
arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
56725676

56735677
lat = torch.asin((2. * arange) / num_surface_dots)
56745678
lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.54"
3+
version = "0.5.55"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)