Skip to content

Commit eae1c01

Browse files
committed
reuse existing code from unresolved rasa to allow Alphafold3 to return List[Structure] and then have the cli be able to save a pdb file
1 parent f4e4bc5 commit eae1c01

File tree

4 files changed

+76
-10
lines changed

4 files changed

+76
-10
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
105105

106106
from Bio.PDB.StructureBuilder import StructureBuilder
107+
from Bio.PDB.Structure import Structure
107108
from Bio.PDB.PDBIO import PDBIO
108109
from Bio.PDB.DSSP import DSSP
109110
import tempfile
@@ -5130,13 +5131,14 @@ def get_cid_molecule_type(
51305131

51315132

51325133
@typecheck
5133-
def _protein_structure_from_feature(
5134+
def protein_structure_from_feature(
51345135
asym_id: Int[" n"],
51355136
molecule_ids: Int[" n"],
51365137
molecule_atom_lens: Int[" n"],
51375138
atom_pos: Float["m 3"],
51385139
atom_mask: Bool[" m"],
5139-
) -> Bio.PDB.Structure.Structure:
5140+
) -> Structure:
5141+
51405142
"""Create structure for unresolved proteins.
51415143
51425144
:param atom_mask: True for valid atoms, False for missing/padding atoms
@@ -5626,7 +5628,7 @@ def _compute_unresolved_rasa(
56265628
chain_atom_pos = atom_pos[chain_mask_to_atom]
56275629
chain_atom_mask = atom_mask[chain_mask_to_atom]
56285630

5629-
structure = _protein_structure_from_feature(
5631+
structure = protein_structure_from_feature(
56305632
chain_asym_id,
56315633
chain_molecule_ids,
56325634
chain_molecule_atom_lens,
@@ -6381,6 +6383,7 @@ def forward(
63816383
return_all_diffused_atom_pos: bool = False,
63826384
return_confidence_head_logits: bool = False,
63836385
return_distogram_head_logits: bool = False,
6386+
return_bio_pdb_structures: bool = False,
63846387
num_rollout_steps: int | None = None,
63856388
rollout_show_tqdm_pbar: bool = False,
63866389
detach_when_recycling: bool = None,
@@ -6390,8 +6393,9 @@ def forward(
63906393
filepaths: List[str] | None = None
63916394
) -> (
63926395
Float['b m 3'] |
6396+
List[Structure] |
63936397
Float['ts b m 3'] |
6394-
Tuple[Float['b m 3'] | Float['ts b m 3'], ConfidenceHeadLogits | Alphafold3Logits] |
6398+
Tuple[Float['b m 3'] | List[Structure] | Float['ts b m 3'], ConfidenceHeadLogits | Alphafold3Logits] |
63956399
Float[''] |
63966400
Tuple[Float[''], LossBreakdown]
63976401
):
@@ -6666,6 +6670,22 @@ def forward(
66666670
if return_confidence_head_logits:
66676671
confidence_head_atom_pos_input = sampled_atom_pos.clone()
66686672

6673+
# convert sampled atom positions to bio pdb structures
6674+
6675+
if return_bio_pdb_structures:
6676+
assert not return_all_diffused_atom_pos
6677+
6678+
sampled_atom_pos = [
6679+
protein_structure_from_feature(*args)
6680+
for args in zip(
6681+
additional_molecule_feats[..., 2],
6682+
molecule_ids,
6683+
molecule_atom_lens,
6684+
sampled_atom_pos,
6685+
atom_mask
6686+
)
6687+
]
6688+
66696689
if not return_confidence_head_logits:
66706690
return sampled_atom_pos
66716691

alphafold3_pytorch/cli.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
alphafold3_inputs_to_batched_atom_input
1010
)
1111

12+
from Bio.PDB.PDBIO import PDBIO
13+
1214
# simple cli using click
1315

1416
@click.command()
1517
@click.option('-ckpt', '--checkpoint', type = str, help = 'path to alphafold3 checkpoint')
1618
@click.option('-p', '--protein', type = str, help = 'one protein sequence')
17-
@click.option('-o', '--output', type = str, help = 'output path', default = 'atompos.pt')
19+
@click.option('-o', '--output', type = str, help = 'output path', default = 'output.pdb')
1820
def cli(
1921
checkpoint: str,
2022
protein: str,
@@ -33,11 +35,13 @@ def cli(
3335
batched_atom_input = alphafold3_inputs_to_batched_atom_input(alphafold3_input, atoms_per_window = alphafold3.atoms_per_window)
3436

3537
alphafold3.eval()
36-
sampled_atom_pos = alphafold3(**batched_atom_input.model_forward_dict())
38+
structure, = alphafold3(**batched_atom_input.model_forward_dict(), return_bio_pdb_structures = True)
3739

3840
output_path = Path(output)
3941
output_path.parents[0].mkdir(exist_ok = True, parents = True)
4042

41-
torch.save(sampled_atom_pos, str(output_path))
43+
pdb_writer = PDBIO()
44+
pdb_writer.set_structure(structure)
45+
pdb_writer.save(str(output_path))
4246

43-
print(f'atomic positions saved to {str(output_path)}')
47+
print(f'pdb saved to {str(output_path)}')

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.4.43"
3+
version = "0.4.44"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_input.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def test_atom_dataset():
6767
# alphafold3 input
6868

6969
@pytest.mark.parametrize('directed_bonds', (False, True))
70-
def test_alphafold3_input(directed_bonds):
70+
def test_alphafold3_input(
71+
directed_bonds
72+
):
7173

7274
alphafold3_input = Alphafold3Input(
7375
proteins = ['MLEICLKLVGCKSKKGLSSSSSCYLEEALQRPVASDF', 'MGKCRGLRTARKLRSHRRDQKWHDKQYKKAHLGTALKANPFGGASHAKGIVLEKVGVEAKQPNSAIRKCVRVQLIKNGKKITAFVPNDGCLNFIEENDEVLVAGFGRKGHAVGDIPGVRFKVVKVANVSLLALYKGKKERPRS'],
@@ -119,6 +121,46 @@ def test_alphafold3_input(directed_bonds):
119121

120122
alphafold3(**batched_atom_input.model_forward_dict(), num_sample_steps = 1)
121123

124+
def test_return_bio_pdb_structures():
125+
126+
alphafold3_input = Alphafold3Input(
127+
proteins = ['MLEICLKLVGCKSKKGLSSSSSCYLEEALQRPVASDF', 'MGKCRGLRTARKLRSHRRDQKWHDKQYKKAHLGTALKANPFGGASHAKGIVLEKVGVEAKQPNSAIRKCVRVQLIKNGKKITAFVPNDGCLNFIEENDEVLVAGFGRKGHAVGDIPGVRFKVVKVANVSLLALYKGKKERPRS'],
128+
)
129+
130+
batched_atom_input = alphafold3_inputs_to_batched_atom_input(alphafold3_input)
131+
132+
# feed it into alphafold3
133+
134+
alphafold3 = Alphafold3(
135+
dim_atom_inputs = 3,
136+
dim_atompair_inputs = 5,
137+
num_atom_embeds = 0,
138+
num_atompair_embeds = 0,
139+
atoms_per_window = 27,
140+
dim_template_feats = 108,
141+
num_dist_bins = 64,
142+
num_molecule_mods = 0,
143+
confidence_head_kwargs = dict(
144+
pairformer_depth = 1
145+
),
146+
template_embedder_kwargs = dict(
147+
pairformer_stack_depth = 1
148+
),
149+
msa_module_kwargs = dict(
150+
depth = 1
151+
),
152+
pairformer_stack = dict(
153+
depth = 2
154+
),
155+
diffusion_module_kwargs = dict(
156+
atom_encoder_depth = 1,
157+
token_transformer_depth = 1,
158+
atom_decoder_depth = 1,
159+
)
160+
)
161+
162+
alphafold3(**batched_atom_input.model_forward_dict(), num_sample_steps = 1, return_bio_pdb_structures = True)
163+
122164
def test_atompos_input():
123165

124166
contrived_protein = 'AG'

0 commit comments

Comments
 (0)