Skip to content

Commit 64fa302

Browse files
committed
start setting up a simple cli for folding from terminal
1 parent 57d4661 commit 64fa302

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed

alphafold3_pytorch/cli.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import click
2+
from pathlib import Path
3+
4+
import torch
5+
6+
from alphafold3_pytorch import (
7+
Alphafold3,
8+
Alphafold3Input,
9+
alphafold3_inputs_to_batched_atom_input
10+
)
11+
12+
# simple cli using click
13+
14+
@click.command()
15+
@click.option('-ckpt', '--checkpoint', type = str, help = 'path to alphafold3 checkpoint')
16+
@click.option('-p', '--protein', type = str, help = 'one protein sequence')
17+
@click.option('-o', '--output', type = str, help = 'output path', default = 'atompos.pt')
18+
def cli(
19+
checkpoint: str,
20+
protein: str,
21+
output: str
22+
):
23+
24+
checkpoint_path = Path(checkpoint)
25+
assert checkpoint_path.exists(), f'alphafold3 checkpoint must exist at {str(checkpoint_path)}'
26+
27+
alphafold3_input = Alphafold3Input(
28+
proteins = [protein],
29+
)
30+
31+
alphafold3 = Alphafold3.init_and_load(checkpoint_path)
32+
33+
batched_atom_input = alphafold3_inputs_to_batched_atom_input(alphafold3_input, atoms_per_window = alphafold3.atoms_per_window)
34+
35+
alphafold3.eval()
36+
sampled_atom_pos = alphafold3(**batched_atom_input.model_forward_dict())
37+
38+
output_path = Path(output)
39+
output_path.parents[0].mkdir(exist_ok = True, parents = True)
40+
41+
torch.save(sampled_atom_pos, str(output_path))
42+
43+
print(f'atomic positions saved to {str(output_path)}')

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.4.37"
3+
version = "0.4.38"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },
@@ -27,7 +27,7 @@ dependencies = [
2727
"adam-atan2-pytorch>=0.0.8",
2828
"beartype",
2929
"biopython>=1.83",
30-
"click",
30+
"click>=8.1",
3131
"CoLT5-attention>=0.11.0",
3232
"einops>=0.8.0",
3333
"einx>=0.2.2",
@@ -72,6 +72,9 @@ test = [
7272
"pytest-shard",
7373
]
7474

75+
[project.scripts]
76+
alphafold3_pytorch = "alphafold3_pytorch.cli:cli"
77+
7578
[build-system]
7679
requires = ["hatchling"]
7780
build-backend = "hatchling.build"

tests/test_cli.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import os
2+
os.environ['TYPECHECK'] = 'True'
3+
os.environ['DEBUG'] = 'True'
4+
from shutil import rmtree
5+
6+
import torch
7+
8+
from alphafold3_pytorch.cli import cli
9+
10+
from alphafold3_pytorch.alphafold3 import (
11+
Alphafold3
12+
)
13+
14+
def test_cli():
15+
alphafold3 = Alphafold3(
16+
dim_atom_inputs = 3,
17+
dim_template_feats = 44,
18+
num_molecule_mods = 0
19+
)
20+
21+
checkpoint_path = './test-folder/test-cli-alphafold3.pt'
22+
alphafold3.save(checkpoint_path, overwrite = True)
23+
24+
cli(['--checkpoint', checkpoint_path, '--protein', 'AG', '--output', './test-folder/atom-pos.pt'], standalone_mode = False)
25+
26+
rmtree('./test-folder')

0 commit comments

Comments
 (0)