Skip to content

Commit d75ca0c

Browse files
authored
rewriting logic for computing rsa for model selection (#297)
1 parent 5e0002c commit d75ca0c

File tree

2 files changed

+228
-13
lines changed

2 files changed

+228
-13
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 196 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,17 @@ def mean_pool_with_lens(
345345
lens: Int['b n']
346346
) -> Float['b n d']:
347347

348+
summed, mask = sum_pool_with_lens(feats, lens)
349+
avg = einx.divide('b n d, b n', summed, lens.clamp(min = 1))
350+
avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.)
351+
return avg
352+
353+
@typecheck
354+
def sum_pool_with_lens(
355+
feats: Float['b m d'],
356+
lens: Int['b n']
357+
) -> tuple[Float['b n d'], Bool['b n']]:
358+
348359
seq_len = feats.shape[1]
349360

350361
mask = lens > 0
@@ -364,9 +375,7 @@ def mean_pool_with_lens(
364375
# subtract cumsum at one index from the previous one
365376
summed = sel_cumsum[:, 1:] - sel_cumsum[:, :-1]
366377

367-
avg = einx.divide('b n d, b n', summed, lens.clamp(min = 1))
368-
avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.)
369-
return avg
378+
return summed, mask
370379

371380
@typecheck
372381
def mean_pool_fixed_windows_with_mask(
@@ -5271,6 +5280,7 @@ def __init__(
52715280
is_fine_tuning: bool = False,
52725281
weight_dict_config: dict = None,
52735282
dssp_path: str = "mkdssp",
5283+
use_inhouse_rsa_calculation: bool = False
52745284
):
52755285
super().__init__()
52765286
self.compute_confidence_score = ComputeConfidenceScore(eps=eps)
@@ -5286,8 +5296,28 @@ def __init__(
52865296

52875297
self.dssp_path = dssp_path
52885298

5299+
self.use_inhouse_rsa_calculation = use_inhouse_rsa_calculation
5300+
5301+
atom_type_radii = tensor([
5302+
1.65, # 0 - nitrogen
5303+
1.87, # 1 - carbon alpha
5304+
1.76, # 2 - carbon
5305+
1.4, # 3 - oxygen
5306+
1.8, # 4 - side atoms
5307+
1.4 # 5 - water
5308+
])
5309+
5310+
self.atom_type_index = dict(
5311+
N = 0,
5312+
CA = 1,
5313+
C = 2,
5314+
O = 3
5315+
) # rest go to 4 (side chain atom)
5316+
5317+
self.register_buffer('atom_radii', atom_type_radii, persistent = False)
5318+
52895319
@property
5290-
def can_calculate_unresolved_protein_rasa(self):
5320+
def is_mkdssp_available(self):
52915321
"""Check if `mkdssp` is available.
52925322
52935323
:return: True if `mkdssp` is available
@@ -5298,6 +5328,10 @@ def can_calculate_unresolved_protein_rasa(self):
52985328
except sh.ErrorReturnCode_1:
52995329
return False
53005330

5331+
@property
def can_calculate_unresolved_protein_rasa(self):
    """Whether some backend is available for the unresolved-protein RASA term.

    The in-house flag is consulted first so that `is_mkdssp_available` (which
    probes the external `mkdssp` binary) is only evaluated when the in-house
    calculation is disabled.

    :return: True if the in-house calculation is enabled or `mkdssp` is available
    """
    return self.use_inhouse_rsa_calculation or self.is_mkdssp_available
5334+
53015335
@typecheck
53025336
def compute_gpde(
53035337
self,
@@ -5673,6 +5707,155 @@ def _compute_unresolved_rasa(
56735707

56745708
return unresolved_rasa.mean()
56755709

5710+
@typecheck
def _inhouse_compute_unresolved_rasa(
    self,
    unresolved_cid: int,
    unresolved_residue_mask: Bool[" n"],
    asym_id: Int[" n"],
    molecule_ids: Int[" n"],
    molecule_atom_lens: Int[" n"],
    atom_pos: Float["m 3"],
    atom_mask: Bool[" m"],
    fibonacci_sphere_n = 200, # they use 200 in mkdssp, but can be tailored for efficiency
    atom_distance_min_thres = 1e-4
) -> Float[""]:
    """Compute the unresolved solvent accessible surface score for one protein chain
    using the in-house (pure torch) re-implementation, without calling `mkdssp`.

    Every atom gets a sphere of radius (atom radius + water radius) sampled with
    `2 * fibonacci_sphere_n + 1` Fibonacci-lattice dots. A dot counts as free when it
    lies outside the equally inflated sphere of every neighbouring atom. The free
    fraction times the squared radius is summed per residue, and the mean over the
    unresolved residues of the chain is returned.

    unresolved_cid: asym_id for protein chains with unresolved residues
    unresolved_residue_mask: True for unresolved residues, False for resolved residues
    asym_id: asym_id for each residue
    molecule_ids: molecule_ids for each residue
    molecule_atom_lens: number of atoms for each residue
    atom_pos: [m 3] atom positions
    atom_mask: True for valid atoms, False for missing/padding atoms
    fibonacci_sphere_n: half-size of the dot lattice (2 * n + 1 dots per atom)
    atom_distance_min_thres: lower bound on the *squared* inter-atom distance for a
        pair to be considered (filters out an atom paired with itself)
    :return: unresolved RASA

    NOTE(review): the value is not divided by a per-residue maximum accessible area
    here, so it appears to be an absolute rather than relative quantity - confirm
    against the mkdssp based `_compute_unresolved_rasa`.
    """

    # no `mkdssp` availability assert here - this code path never invokes it

    # keep every intermediate on the device of the inputs; the radii buffer is moved if needed

    device = atom_pos.device
    radii_table = self.atom_radii.to(device)
    dtype = radii_table.dtype

    num_atom = atom_pos.shape[0]

    chain_mask = asym_id == unresolved_cid
    chain_unresolved_residue_mask = unresolved_residue_mask[chain_mask]
    chain_asym_id = asym_id[chain_mask]
    chain_molecule_ids = molecule_ids[chain_mask]
    chain_molecule_atom_lens = molecule_atom_lens[chain_mask]

    chain_mask_to_atom = torch.repeat_interleave(chain_mask, molecule_atom_lens)

    # if there's padding in num atom
    # (plain python int, as `F.pad` expects integer pad sizes)

    num_pad = num_atom - int(molecule_atom_lens.sum())

    if num_pad > 0:
        chain_mask_to_atom = F.pad(chain_mask_to_atom, (0, num_pad), value = False)

    chain_atom_pos = atom_pos[chain_mask_to_atom]
    chain_atom_mask = atom_mask[chain_mask_to_atom]

    structure = protein_structure_from_feature(
        chain_asym_id,
        chain_molecule_ids,
        chain_molecule_atom_lens,
        chain_atom_pos,
        chain_atom_mask,
    )

    # use the structure as source of truth, matching what xluo did
    # NOTE(review): if the structure omits masked / missing atoms, the number of atoms
    # here will differ from `chain_molecule_atom_lens.sum()` used for pooling below - confirm

    side_atom_index = len(self.atom_type_index) # any atom name not in the index is treated as a side chain atom

    structure_atom_pos = []
    structure_atom_type_for_radii = []

    for atom in structure.get_atoms():
        structure_atom_pos.append(list(atom.get_vector()))
        structure_atom_type_for_radii.append(self.atom_type_index.get(atom.name, side_atom_index))

    structure_atom_pos: Float[' m 3'] = tensor(structure_atom_pos, dtype = dtype, device = device)
    structure_atom_type_for_radii: Int[' m'] = tensor(structure_atom_type_for_radii, dtype = torch.long, device = device)

    # atom radii is always summed with water radii (last entry of the radii table)

    water_radii = radii_table[-1]

    atom_radii: Float[' m'] = radii_table[structure_atom_type_for_radii] + water_radii
    atom_radii_sq = atom_radii.pow(2) # always use square of radii or distance for comparison - save on sqrt

    # first constitute the fibonacci sphere

    num_surface_dots = fibonacci_sphere_n * 2 + 1
    golden_ratio = (1. + sqrt(5.)) / 2 # parenthesized - `1. + sqrt(5.) / 2` would be ~2.118, not the golden ratio
    weight = (4. * pi) / num_surface_dots

    arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = device, dtype = dtype) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]

    lat = torch.asin((2. * arange) / num_surface_dots)
    lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio

    # ein:
    # sd - surface dots
    # c - coordinate (3)
    # i, j - source and target atom

    unit_surface_dots: Float['sd 3'] = torch.stack((
        lon.sin() * lat.cos(),
        lon.cos() * lat.cos(),
        lat.sin()
    ), dim = -1)

    # first get atom relative positions + distance
    # for determining whether to include pairs of atom in calculation for the `free` adjective

    atom_rel_pos = einx.subtract('j c, i c -> i j c', structure_atom_pos, structure_atom_pos)
    atom_rel_dist_sq = atom_rel_pos.pow(2).sum(dim = -1)

    max_distance_include = einx.add('i, j -> i j', atom_radii, atom_radii).pow(2)

    # a pair matters only if the two inflated spheres overlap and the atoms do not coincide

    include_in_free_calc = (
        (atom_rel_dist_sq < max_distance_include) &
        (atom_rel_dist_sq > atom_distance_min_thres)
    )

    # overall logic

    # dots on the inflated sphere of each source atom, expressed relative to that atom

    surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots)

    # todo - this intermediate holds (m * sd * m * 3) values, chunk over source atoms for larger chains

    dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1)

    # True where the dot of atom i lies outside the inflated sphere of target atom j

    dot_outside_target_sphere = einx.less('j, i sd j -> i sd j', atom_radii_sq, dist_from_surface_dots_sq)

    dot_outside_or_pair_excluded = einx.logical_or('i sd j, i j -> i sd j', dot_outside_target_sphere, ~include_in_free_calc)

    is_free = reduce(dot_outside_or_pair_excluded, 'i sd j -> i sd', 'all') # basically the most important line, a dot is free only if no included target atom covers it

    score = reduce(is_free.float() * weight, 'm sd -> m', 'sum')

    per_atom_accessible_surface_score = score * atom_radii_sq

    # sum up all surface scores for atoms per residue
    # the final score seems to be the average of the rsa across all residues (selected by `chain_unresolved_residue_mask`)

    rasa, mask = sum_pool_with_lens(
        rearrange(per_atom_accessible_surface_score, '... -> 1 ... 1'),
        rearrange(chain_molecule_atom_lens, '... -> 1 ...')
    )

    rasa = einx.where('b n, b n d, -> b n d', mask, rasa, 0.)

    rasa = rearrange(rasa, '1 n 1 -> n')

    unresolved_rasa = rasa[chain_unresolved_residue_mask]

    return unresolved_rasa.mean()
5858+
56765859
@typecheck
56775860
def compute_unresolved_rasa(
56785861
self,
@@ -5707,8 +5890,16 @@ def compute_unresolved_rasa(
57075890
weight = weight_dict.get("unresolved", {}).get("unresolved", None)
57085891
assert weight, "Weight not found for unresolved"
57095892

5893+
# for migrating to a rewritten computation of RSA
5894+
# to remove mkdssp dependency for model selection
5895+
5896+
if self.use_inhouse_rsa_calculation:
5897+
compute_unresolved_rasa_function = self._inhouse_compute_unresolved_rasa
5898+
else:
5899+
compute_unresolved_rasa_function = self._compute_unresolved_rasa
5900+
57105901
unresolved_rasa = [
5711-
self._compute_unresolved_rasa(*args)
5902+
compute_unresolved_rasa_function(*args)
57125903
for args in zip(
57135904
unresolved_cid,
57145905
unresolved_residue_mask,

tests/test_af3.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from alphafold3_pytorch.inputs import (
6161
IS_MOLECULE_TYPES,
6262
IS_PROTEIN,
63+
IS_PROTEIN_INDEX,
6364
atom_ref_pos_to_atompair_inputs,
6465
molecule_to_atom_input,
6566
pdb_input_to_molecule_input,
@@ -1469,39 +1470,62 @@ def test_unresolved_protein_rasa():
14691470
DATA_TEST_PDB_ID[1:3],
14701471
f"{DATA_TEST_PDB_ID}-assembly1.cif",
14711472
)
1473+
14721474
pdb_input = PDBInput(mmcif_filepath, inference=True)
14731475

14741476
mol_input = pdb_input_to_molecule_input(pdb_input)
14751477
atom_input = molecule_to_atom_input(mol_input)
14761478
batched_atom_input = collate_inputs_to_batched_atom_input([atom_input], atoms_per_window=27)
1477-
batched_atom_input_dict = batched_atom_input.dict()
14781479

1479-
res_idx, token_idx, asym_id, entity_id, sym_id = batched_atom_input_dict['additional_molecule_feats'].unbind(dim = -1)
1480+
molecule_atom_lens = batched_atom_input.molecule_atom_lens
1481+
is_molecule_types = batched_atom_input.is_molecule_types
1482+
1483+
is_protein = is_molecule_types[0, ..., IS_PROTEIN_INDEX]
1484+
1485+
# random crop for inhouse for now
1486+
# todo - remove this and make efficient
1487+
is_protein[200:] = False
1488+
1489+
atom_is_from_protein = torch.repeat_interleave(is_protein, molecule_atom_lens[0])
1490+
1491+
res_idx, token_idx, asym_id, entity_id, sym_id = batched_atom_input.additional_molecule_feats.unbind(dim = -1)
14801492

14811493
cid = 1
14821494
res_chem_index = get_cid_molecule_type(
14831495
cid,
14841496
asym_id[0],
1485-
batched_atom_input_dict['is_molecule_types'][0])
1497+
is_molecule_types[0]
1498+
)
14861499

14871500
# only support protein for unresolved protein calculate
14881501
assert res_chem_index == IS_PROTEIN
14891502

14901503
unresolved_residue_mask = torch.randint(0, 2, asym_id.shape).bool()
14911504

1492-
compute_model_selection_score = ComputeModelSelectionScore()
1505+
compute_model_selection_score = ComputeModelSelectionScore(
1506+
use_inhouse_rsa_calculation = True
1507+
)
14931508

14941509
if not compute_model_selection_score.can_calculate_unresolved_protein_rasa:
14951510
pytest.skip("mkdssp not available for calculating unresolved protein rasa")
14961511

1512+
# only test with protein
1513+
1514+
unresolved_residue_mask = unresolved_residue_mask[:, is_protein]
1515+
asym_id = asym_id[:, is_protein]
1516+
molecule_ids = batched_atom_input.molecule_ids[:, is_protein]
1517+
atom_pos = batched_atom_input.atom_pos[:, atom_is_from_protein]
1518+
missing_atom_mask = batched_atom_input.missing_atom_mask[:, atom_is_from_protein]
1519+
molecule_atom_lens = molecule_atom_lens[:, is_protein]
1520+
14971521
unresolved_rasa = compute_model_selection_score.compute_unresolved_rasa(
14981522
unresolved_cid=[1],
14991523
unresolved_residue_mask=unresolved_residue_mask,
15001524
asym_id = asym_id,
1501-
molecule_ids=batched_atom_input_dict['molecule_ids'],
1502-
molecule_atom_lens=batched_atom_input_dict['molecule_atom_lens'],
1503-
atom_pos=batched_atom_input_dict['atom_pos'],
1504-
atom_mask=~batched_atom_input_dict['missing_atom_mask'])
1525+
molecule_ids=molecule_ids,
1526+
molecule_atom_lens=molecule_atom_lens,
1527+
atom_pos=atom_pos,
1528+
atom_mask=~missing_atom_mask)
15051529

15061530
def test_readme1():
15071531
alphafold3 = Alphafold3(dim_atom_inputs=77, dim_template_feats=108)

0 commit comments

Comments
 (0)