Skip to content

Commit e00a599

Browse files
authored
#221 WIP: Basic interface, mapping various molecule types to AF3 input (#247)
Known issue: the app still contains bugs. After clicking the Predict button, the rest of the UI stops responding.
1 parent 7803788 commit e00a599

File tree

1 file changed

+247
-24
lines changed

1 file changed

+247
-24
lines changed

alphafold3_pytorch/app.py

Lines changed: 247 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,274 @@
11
import click
22
from pathlib import Path
33

4+
import secrets
5+
import shutil
46
import gradio as gr
7+
from gradio_molecule3d import Molecule3D
8+
from Bio.PDB import PDBIO
59

6-
from alphafold3_pytorch import (
7-
Alphafold3,
8-
Alphafold3Input,
9-
alphafold3_inputs_to_batched_atom_input
10-
)
10+
from alphafold3_pytorch import Alphafold3, Alphafold3Input
1111

1212
# module-level state (model and cache_path are assigned by the `app` command)
13-
1413
model = None
14+
cache_path = None
15+
pdb_writer = PDBIO()
16+
1517

16-
# main fold functoin
18+
# main fold function
19+
def fold(entities, request: gr.Request):
    """Predict a structure for the given entities and write it to a PDB file.

    Args:
        entities: Iterable of dicts with the keys ``"mol_type"`` (``"Protein"``,
            ``"RNA"``, ``"DNA"``, ``"Ligand"`` or ``"Ion"``), ``"sequence"`` and
            ``"num_copies"``. Entities with any other ``"mol_type"`` are ignored.
        request: Gradio request; its ``session_hash`` names the per-session
            sub-directory of ``cache_path`` that receives the output file.

    Returns:
        The path of the written ``.pdb`` file, as a string.
    """
    # One bucket per molecule type; a sequence is listed once per copy.
    sequences_by_type = {
        "Protein": [],
        "RNA": [],
        "DNA": [],
        "Ligand": [],
        "Ion": [],
    }
    for entity in entities:
        bucket = sequences_by_type.get(entity["mol_type"])
        if bucket is None:
            continue
        # The copy count may arrive as a float (e.g. 2.0) from the number
        # widget; list repetition only accepts integers.
        bucket.extend([entity["sequence"]] * int(entity["num_copies"]))

    # Prepare the input for the model
    alphafold3_input = Alphafold3Input(
        proteins=sequences_by_type["Protein"],
        ss_dna=sequences_by_type["DNA"],
        ss_rna=sequences_by_type["RNA"],
        ligands=sequences_by_type["Ligand"],
        metal_ions=sequences_by_type["Ion"],
    )

    # Run the model inference (no thread is started here; this call blocks
    # until the prediction is finished)
    model.eval()
    (structure,) = model.forward_with_alphafold3_inputs(
        alphafold3_inputs=alphafold3_input,
        return_bio_pdb_structures=True,
    )

    # A random file name keeps repeated predictions of one session apart.
    output_path = cache_path / str(request.session_hash) / f"{secrets.token_urlsafe(8)}.pdb"
    # parents=True: do not fail if the cache root itself is missing.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    pdb_writer.set_structure(structure)
    pdb_writer.save(str(output_path))

    return str(output_path)
2762

2863
# gradio
64+
def delete_cache(request: gr.Request):
    """Remove the cache sub-directory that belongs to the request's session.

    Nothing happens when the request carries no session hash or when the
    session directory does not exist.
    """
    if request.session_hash:
        session_dir: Path = cache_path / request.session_hash
        if session_dir.exists():
            shutil.rmtree(str(session_dir))
3971

40-
# cli
4172

73+
# Ligand labels offered in the UI. add_entity keeps only the code that
# precedes " - ", so the descriptive part is free text.
LIGAND_CHOICES = [
    "ADP - Adenosine diphosphate",
    "ATP - Adenosine triphosphate",
    "AMP - Adenosine monophosphate",
    "GTP - Guanosine-5'-triphosphate",
    "GDP - Guanosine-5'-diphosphate",
    "FAD - Flavin adenine dinucleotide",
    "NAD - Nicotinamide-adenine-dinucleotide",
    "NAP - Nicotinamide-adenine-dinucleotide phosphate (NADP)",
    "NDP - Dihydro-nicotinamide-adenine-dinucleotide-phosphate (NADPH)",
    "HEM - Heme",
    "HEC - Heme C",
    "OLA - Oleic acid",
    "MYR - Myristic acid",
    "CIT - Citric acid",
    "CLA - Chlorophyll A",
    "CHL - Chlorophyll B",
    "BCL - Bacteriochlorophyll A",
    "BCB - Bacteriochlorophyll B",
]

# Ion labels offered in the UI. add_entity strips the charge marks and keeps
# only the alphabetic characters.
ION_CHOICES = [
    "Mg²⁺",
    "Zn²⁺",
    "Cl⁻",
    "Ca²⁺",
    "Na⁺",
    "Mn²⁺",
    "K⁺",
    "Fe³⁺",
    "Cu²⁺",
    "Co²⁺",
]

with gr.Blocks(delete_cache=(600, 3600)) as gradio_app:
    # Per-session list of entity dicts built up by the user.
    entities = gr.State([])

    with gr.Row():
        gr.Markdown("### AlphaFold3 PyTorch Web UI")

    with gr.Row():
        gr.Column(scale=8)
        # upload_json_button = gr.Button("Upload JSON", scale=1, min_width=100)
        clear_button = gr.Button("Clear", scale=1, min_width=100)

    with gr.Row():
        with gr.Column(scale=1, min_width=150):
            mtype = gr.Dropdown(
                value="Protein",
                label="Molecule type",
                choices=["Protein", "DNA", "RNA", "Ligand", "Ion"],
                interactive=True,
            )
        with gr.Column(scale=1, min_width=80):
            # precision=0 / minimum=1: the copy count is used to repeat a
            # list, so it has to be a positive integer.
            c = gr.Number(
                value=1,
                label="Copies",
                minimum=1,
                precision=0,
                interactive=True,
            )

        with gr.Column(scale=8, min_width=200):

            @gr.render(inputs=mtype)
            def render_sequence(mol_type):
                """Render the sequence input that fits the selected molecule type."""
                if mol_type in ["Protein", "DNA", "RNA"]:
                    seq = gr.Textbox(
                        label="Paste sequence or fasta",
                        placeholder="Input",
                        interactive=True,
                    )
                elif mol_type == "Ligand":
                    seq = gr.Dropdown(
                        label="Select ligand",
                        choices=LIGAND_CHOICES,
                        interactive=True,
                    )
                elif mol_type == "Ion":
                    seq = gr.Dropdown(
                        label="Select ion",
                        choices=ION_CHOICES,
                        interactive=True,
                    )
                else:
                    # No input widget exists for this type, so there is
                    # nothing to wire up.
                    return

                # Wired here because `seq` is re-created on every render.
                add_button.click(add_entity, inputs=[entities, mtype, c, seq], outputs=[entities])
                clear_button.click(lambda: ("Protein", 1, None), None, outputs=[mtype, c, seq])

        # NOTE(review): the original indentation is not recoverable from this
        # view; the button is assumed to sit inside the input row - confirm.
        add_button = gr.Button("Add entity", scale=1, min_width=100)
156+
157+
def add_entity(entities, mtype="Protein", c=1, seq=""):
    """Validate one user-supplied entity and append it to the entity list.

    Args:
        entities: Current list of entity dicts. It is never mutated.
        mtype: Molecule type: "Protein", "DNA", "RNA", "Ligand" or "Ion".
        c: Number of copies; must be a whole number >= 1 (``2.0`` is accepted
            and stored as ``2``).
        seq: For polymers, the raw sequence or a single FASTA record (header
            lines starting with ``>`` and all whitespace are dropped). For
            ligands and ions, the label selected in the dropdown.

    Returns:
        A new list with the normalized entity appended. When validation fails
        a ``gr.Info`` message is shown and ``entities`` is returned unchanged.
    """
    if seq is None or len(seq) == 0:
        gr.Info("Input required")
        return entities

    # The copy count is later used to repeat a list, so it must be a positive
    # integer; reject None, fractions, NaN/inf and non-numeric values.
    try:
        num_copies = int(c)
    except (TypeError, ValueError, OverflowError):
        num_copies = 0
    if num_copies < 1 or num_copies != c:
        gr.Info("Number of copies must be a positive whole number")
        return entities

    # Allowed residue alphabet and error message per polymer type.
    polymer_rules = {
        "Protein": (
            "ARDCQEGHILKMNFPSTWYV",
            "Invalid protein sequence. Allowed characters: A, R, D, C, Q, E, G, H, I, L, K, M, N, F, P, S, T, W, Y, V",
        ),
        "DNA": ("ACGT", "Invalid DNA sequence. Allowed characters: A, C, G, T"),
        "RNA": ("ACGU", "Invalid RNA sequence. Allowed characters: A, C, G, U"),
    }

    if mtype in polymer_rules:
        lines = seq.splitlines()
        header_count = sum(1 for line in lines if line.lstrip().startswith(">"))
        if header_count > 1:
            # Concatenating several records would silently build a chimera.
            gr.Info("Only one FASTA record per entity is supported")
            return entities

        # Drop the FASTA header and every whitespace character (line wraps,
        # spaces between blocks), then normalize the case.
        seq_norm = "".join(
            "".join(line.split()) for line in lines if not line.lstrip().startswith(">")
        ).upper()

        allowed, message = polymer_rules[mtype]
        if any(residue not in allowed for residue in seq_norm):
            gr.Info(message)
            return entities

        # Measured on the normalized sequence so that padding, headers or
        # whitespace-only input cannot satisfy the minimum length.
        if len(seq_norm) < 4:
            gr.Info("Minimum 4 characters required")
            return entities

    elif mtype == "Ligand":
        # Labels look like "ATP - Adenosine triphosphate"; keep the code.
        seq_norm = seq.split(" - ")[0]
    elif mtype == "Ion":
        # Labels look like "Mg²⁺"; keep only the element symbol.
        seq_norm = "".join([x for x in seq if x.isalpha()])
    else:
        # Unknown types are passed through with basic normalization only.
        seq_norm = seq.strip(" \t\n\r").upper()

    new_entity = {"mol_type": mtype, "num_copies": num_copies, "sequence": seq_norm}

    return entities + [new_entity]
195+
196+
@gr.render(inputs=entities)
def render_entities(entity_list):
    """Draw one read-only row (type, copies, sequence, delete button) per entity."""
    for position, item in enumerate(entity_list):
        kind = item["mol_type"]
        shown_sequence = item["sequence"]
        if kind not in ["Ligand", "Ion"]:
            # Polymers are displayed in tab-separated groups of 10 residues.
            groups = []
            for start in range(0, len(shown_sequence), 10):
                groups.append(shown_sequence[start : start + 10])
            shown_sequence = "\t".join(groups)

        with gr.Row():
            gr.Text(
                value=kind,
                label="Type",
                scale=1,
                min_width=90,
                interactive=False,
            )
            gr.Text(
                value=item["num_copies"],
                label="Copies",
                scale=1,
                min_width=80,
                interactive=False,
            )
            gr.Text(
                value=shown_sequence,
                label="Sequence",
                placeholder="Input",
                scale=7,
                min_width=200,
                interactive=False,
            )

            del_button = gr.Button("🗑️", scale=0, min_width=50)

            # The default argument freezes this row's index for the handler.
            def delete(entity_id=position):
                entity_list.pop(entity_id)
                return entity_list

            del_button.click(delete, None, outputs=[entities])
236+
237+
# Prediction trigger and the 3D viewer that displays the resulting PDB file.
pred_button = gr.Button("Predict", scale=1, min_width=100)
viewer_config = {"backgroundColor": "black"}
output_mol = Molecule3D(label="Output structure", config=viewer_config)

# "Predict" runs fold on the current entity list and shows its output file.
pred_button.click(fn=fold, inputs=entities, outputs=output_mol)
# "Clear" empties the entity list and resets the viewer.
clear_button.click(fn=lambda: ([], None), inputs=None, outputs=[entities, output_mol])

# Registers delete_cache as the handler for the app's unload event.
gradio_app.unload(fn=delete_cache)
244+
245+
246+
# cli
42247
@click.command()
@click.option("-ckpt", "--checkpoint", type=str, help="path to alphafold3 checkpoint", required=True)
@click.option("-cache", "--cache-dir", type=str, help="path to output cache", required=False, default="cache")
@click.option("-prec", "--precision", type=str, help="precision to use", required=False, default="float32")
def app(checkpoint: str, cache_dir: str, precision: str):
    """Load the checkpoint, reset the output cache and launch the Gradio UI.

    Args:
        checkpoint: Path to the Alphafold3 checkpoint; it must exist.
        cache_dir: Directory for generated PDB files. WARNING: an existing
            directory at this path is deleted recursively on start-up.
        precision: Accepted for forward compatibility; currently unused
            (see the TODO below).

    Raises:
        click.BadParameter: If ``checkpoint`` does not exist.
    """
    global cache_path, model

    path = Path(checkpoint)
    # An explicit error instead of `assert`, which is stripped under `python -O`.
    if not path.exists():
        raise click.BadParameter("checkpoint does not exist at path", param_hint="--checkpoint")

    cache_path = Path(cache_dir)

    # Start from an empty cache. NOTE(review): this removes whatever directory
    # the user points --cache-dir at, not only files created by this app.
    if cache_path.exists():
        shutil.rmtree(str(cache_path))

    # parents=True so nested cache locations such as "out/cache" work.
    cache_path.mkdir(parents=True, exist_ok=True)

    model = Alphafold3.init_and_load(str(path))
    # TODO: To device and quantize? Sketch: pick "cuda" when
    # torch.cuda.is_available() else "cpu", resolve `precision` with
    # getattr(torch, precision) (fall back to torch.float32 on
    # AttributeError), then model.to(device, dtype=dtype).

    gradio_app.launch()

0 commit comments

Comments
 (0)