Skip to content

Commit e99cfb4

Browse files
committed
gradio app should only be created after cli is invoked
1 parent ba0a326 commit e99cfb4

File tree

2 files changed

+164
-163
lines changed

2 files changed

+164
-163
lines changed

alphafold3_pytorch/app.py

Lines changed: 163 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -70,178 +70,179 @@ def delete_cache(request: gr.Request):
7070
shutil.rmtree(str(user_dir))
7171

7272

73-
with gr.Blocks(delete_cache=(600, 3600)) as gradio_app:
74-
entities = gr.State([])
75-
76-
with gr.Row():
77-
gr.Markdown("### AlphaFold3 PyTorch Web UI")
78-
79-
with gr.Row():
80-
gr.Column(scale=8)
81-
# upload_json_button = gr.Button("Upload JSON", scale=1, min_width=100)
82-
clear_button = gr.Button("Clear", scale=1, min_width=100)
83-
84-
with gr.Row():
85-
with gr.Column(scale=1, min_width=150):
86-
mtype = gr.Dropdown(
87-
value="Protein",
88-
label="Molecule type",
89-
choices=["Protein", "DNA", "RNA", "Ligand", "Ion"],
90-
interactive=True,
91-
)
92-
with gr.Column(scale=1, min_width=80):
93-
c = gr.Number(
94-
value=1,
95-
label="Copies",
96-
interactive=True,
97-
)
98-
99-
with gr.Column(scale=8, min_width=200):
100-
101-
@gr.render(inputs=mtype)
102-
def render_sequence(mol_type):
103-
if mol_type in ["Protein", "DNA", "RNA"]:
104-
seq = gr.Textbox(
105-
label="Paste sequence or fasta",
106-
placeholder="Input",
107-
interactive=True,
108-
)
109-
elif mol_type == "Ligand":
110-
seq = gr.Dropdown(
111-
label="Select ligand",
112-
choices=[
113-
"ADP - Adenosine disphosphate",
114-
"ATP - Adenosine triphosphate",
115-
"AMP - Adenosine monophosphate",
116-
"GTP - Guanosine-5'-triphosphate",
117-
"GDP - Guanosine-5'-diphosphate",
118-
"FAD - Flavin adenine dinucleotide",
119-
"NAD - Nicotinamide-adenine-dinucleotide",
120-
"NAP - Nicotinamide-adenine-dinucleotide phosphate (NADP)",
121-
"NDP - Dihydro-nicotinamide-adenine-dinucleotide-phosphate (NADPH)",
122-
"HEM - Heme",
123-
"HEC - Heme C",
124-
"OLA - Oleic acid",
125-
"MYR - Myristic acid",
126-
"CIT - Citric acid",
127-
"CLA - Chlorophyll A",
128-
"CHL - Chlorophyll B",
129-
"BCL - Bacteriochlorophyll A",
130-
"BCB - Bacteriochlorophyll B",
131-
],
132-
interactive=True,
133-
)
134-
elif mol_type == "Ion":
135-
seq = gr.Dropdown(
136-
label="Select ion",
137-
choices=[
138-
"Mg²⁺",
139-
"Zn²⁺",
140-
"Cl⁻",
141-
"Ca²⁺",
142-
"Na⁺",
143-
"Mn²⁺",
144-
"K⁺",
145-
"Fe³⁺",
146-
"Cu²⁺",
147-
"Co²⁺",
148-
],
149-
interactive=True,
150-
)
151-
152-
add_button.click(add_entity, inputs=[entities, mtype, c, seq], outputs=[entities])
153-
clear_button.click(lambda: ("Protein", 1, None), None, outputs=[mtype, c, seq])
154-
155-
add_button = gr.Button("Add entity", scale=1, min_width=100)
156-
157-
def add_entity(entities, mtype="Protein", c=1, seq=""):
158-
if seq is None or len(seq) == 0:
159-
gr.Info("Input required")
160-
return entities
161-
162-
seq_norm = seq.strip(" \t\n\r").upper()
163-
164-
if mtype in ["Protein", "DNA", "RNA"]:
165-
if mtype == "Protein" and any([x not in "ARDCQEGHILKMNFPSTWYV" for x in seq_norm]):
166-
gr.Info("Invalid protein sequence. Allowed characters: A, R, D, C, Q, E, G, H, I, L, K, M, N, F, P, S, T, W, Y, V")
167-
return entities
168-
169-
if mtype == "DNA" and any([x not in "ACGT" for x in seq_norm]):
170-
gr.Info("Invalid DNA sequence. Allowed characters: A, C, G, T")
171-
return entities
172-
173-
if mtype == "RNA" and any([x not in "ACGU" for x in seq_norm]):
174-
gr.Info("Invalid RNA sequence. Allowed characters: A, C, G, U")
175-
return entities
176-
177-
if len(seq) < 4:
178-
gr.Info("Minimum 4 characters required")
179-
return entities
180-
181-
elif mtype == "Ligand":
182-
if seq is None or len(seq) == 0:
183-
gr.Info("Select a ligand")
184-
return entities
185-
seq_norm = seq.split(" - ")[0]
186-
elif mtype == "Ion":
187-
if seq is None or len(seq) == 0:
188-
gr.Info("Select an ion")
189-
return entities
190-
seq_norm = "".join([x for x in seq if x.isalpha()])
191-
192-
new_entity = {"mol_type": mtype, "num_copies": c, "sequence": seq_norm}
193-
194-
return entities + [new_entity]
195-
196-
@gr.render(inputs=entities)
197-
def render_entities(entity_list):
198-
for idx, entity in enumerate(entity_list):
199-
with gr.Row():
200-
gr.Text(
201-
value=entity["mol_type"],
202-
label="Type",
203-
scale=1,
204-
min_width=90,
205-
interactive=False,
73+
def start_gradio_app():
74+
with gr.Blocks(delete_cache=(600, 3600)) as gradio_app:
75+
entities = gr.State([])
76+
77+
with gr.Row():
78+
gr.Markdown("### AlphaFold3 PyTorch Web UI")
79+
80+
with gr.Row():
81+
gr.Column(scale=8)
82+
# upload_json_button = gr.Button("Upload JSON", scale=1, min_width=100)
83+
clear_button = gr.Button("Clear", scale=1, min_width=100)
84+
85+
with gr.Row():
86+
with gr.Column(scale=1, min_width=150):
87+
mtype = gr.Dropdown(
88+
value="Protein",
89+
label="Molecule type",
90+
choices=["Protein", "DNA", "RNA", "Ligand", "Ion"],
91+
interactive=True,
20692
)
207-
gr.Text(
208-
value=entity["num_copies"],
93+
with gr.Column(scale=1, min_width=80):
94+
c = gr.Number(
95+
value=1,
20996
label="Copies",
210-
scale=1,
211-
min_width=80,
212-
interactive=False,
97+
interactive=True,
21398
)
21499

215-
sequence = entity["sequence"]
216-
if entity["mol_type"] not in ["Ligand", "Ion"]:
217-
# Split every 10 characters, and add a \t after each split
218-
sequence = "\t".join([sequence[i : i + 10] for i in range(0, len(sequence), 10)])
219-
220-
gr.Text(
221-
value=sequence,
222-
label="Sequence",
223-
placeholder="Input",
224-
scale=7,
225-
min_width=200,
226-
interactive=False,
227-
)
100+
with gr.Column(scale=8, min_width=200):
101+
102+
@gr.render(inputs=mtype)
103+
def render_sequence(mol_type):
104+
if mol_type in ["Protein", "DNA", "RNA"]:
105+
seq = gr.Textbox(
106+
label="Paste sequence or fasta",
107+
placeholder="Input",
108+
interactive=True,
109+
)
110+
elif mol_type == "Ligand":
111+
seq = gr.Dropdown(
112+
label="Select ligand",
113+
choices=[
114+
"ADP - Adenosine disphosphate",
115+
"ATP - Adenosine triphosphate",
116+
"AMP - Adenosine monophosphate",
117+
"GTP - Guanosine-5'-triphosphate",
118+
"GDP - Guanosine-5'-diphosphate",
119+
"FAD - Flavin adenine dinucleotide",
120+
"NAD - Nicotinamide-adenine-dinucleotide",
121+
"NAP - Nicotinamide-adenine-dinucleotide phosphate (NADP)",
122+
"NDP - Dihydro-nicotinamide-adenine-dinucleotide-phosphate (NADPH)",
123+
"HEM - Heme",
124+
"HEC - Heme C",
125+
"OLA - Oleic acid",
126+
"MYR - Myristic acid",
127+
"CIT - Citric acid",
128+
"CLA - Chlorophyll A",
129+
"CHL - Chlorophyll B",
130+
"BCL - Bacteriochlorophyll A",
131+
"BCB - Bacteriochlorophyll B",
132+
],
133+
interactive=True,
134+
)
135+
elif mol_type == "Ion":
136+
seq = gr.Dropdown(
137+
label="Select ion",
138+
choices=[
139+
"Mg²⁺",
140+
"Zn²⁺",
141+
"Cl⁻",
142+
"Ca²⁺",
143+
"Na⁺",
144+
"Mn²⁺",
145+
"K⁺",
146+
"Fe³⁺",
147+
"Cu²⁺",
148+
"Co²⁺",
149+
],
150+
interactive=True,
151+
)
152+
153+
add_button.click(add_entity, inputs=[entities, mtype, c, seq], outputs=[entities])
154+
clear_button.click(lambda: ("Protein", 1, None), None, outputs=[mtype, c, seq])
155+
156+
add_button = gr.Button("Add entity", scale=1, min_width=100)
157+
158+
def add_entity(entities, mtype="Protein", c=1, seq=""):
159+
if seq is None or len(seq) == 0:
160+
gr.Info("Input required")
161+
return entities
162+
163+
seq_norm = seq.strip(" \t\n\r").upper()
164+
165+
if mtype in ["Protein", "DNA", "RNA"]:
166+
if mtype == "Protein" and any([x not in "ARDCQEGHILKMNFPSTWYV" for x in seq_norm]):
167+
gr.Info("Invalid protein sequence. Allowed characters: A, R, D, C, Q, E, G, H, I, L, K, M, N, F, P, S, T, W, Y, V")
168+
return entities
169+
170+
if mtype == "DNA" and any([x not in "ACGT" for x in seq_norm]):
171+
gr.Info("Invalid DNA sequence. Allowed characters: A, C, G, T")
172+
return entities
173+
174+
if mtype == "RNA" and any([x not in "ACGU" for x in seq_norm]):
175+
gr.Info("Invalid RNA sequence. Allowed characters: A, C, G, U")
176+
return entities
177+
178+
if len(seq) < 4:
179+
gr.Info("Minimum 4 characters required")
180+
return entities
181+
182+
elif mtype == "Ligand":
183+
if seq is None or len(seq) == 0:
184+
gr.Info("Select a ligand")
185+
return entities
186+
seq_norm = seq.split(" - ")[0]
187+
elif mtype == "Ion":
188+
if seq is None or len(seq) == 0:
189+
gr.Info("Select an ion")
190+
return entities
191+
seq_norm = "".join([x for x in seq if x.isalpha()])
192+
193+
new_entity = {"mol_type": mtype, "num_copies": c, "sequence": seq_norm}
194+
195+
return entities + [new_entity]
196+
197+
@gr.render(inputs=entities)
198+
def render_entities(entity_list):
199+
for idx, entity in enumerate(entity_list):
200+
with gr.Row():
201+
gr.Text(
202+
value=entity["mol_type"],
203+
label="Type",
204+
scale=1,
205+
min_width=90,
206+
interactive=False,
207+
)
208+
gr.Text(
209+
value=entity["num_copies"],
210+
label="Copies",
211+
scale=1,
212+
min_width=80,
213+
interactive=False,
214+
)
228215

229-
del_button = gr.Button("🗑️", scale=0, min_width=50)
216+
sequence = entity["sequence"]
217+
if entity["mol_type"] not in ["Ligand", "Ion"]:
218+
# Split every 10 characters, and add a \t after each split
219+
sequence = "\t".join([sequence[i : i + 10] for i in range(0, len(sequence), 10)])
220+
221+
gr.Text(
222+
value=sequence,
223+
label="Sequence",
224+
placeholder="Input",
225+
scale=7,
226+
min_width=200,
227+
interactive=False,
228+
)
230229

231-
def delete(entity_id=idx):
232-
entity_list.pop(entity_id)
233-
return entity_list
230+
del_button = gr.Button("🗑️", scale=0, min_width=50)
234231

235-
del_button.click(delete, None, outputs=[entities])
232+
def delete(entity_id=idx):
233+
entity_list.pop(entity_id)
234+
return entity_list
236235

237-
pred_button = gr.Button("Predict", scale=1, min_width=100)
238-
output_mol = Molecule3D(label="Output structure", config={"backgroundColor": "black"})
236+
del_button.click(delete, None, outputs=[entities])
239237

240-
pred_button.click(fold, inputs=entities, outputs=output_mol)
241-
clear_button.click(lambda: ([], None), None, outputs=[entities, output_mol])
238+
pred_button = gr.Button("Predict", scale=1, min_width=100)
239+
output_mol = Molecule3D(label="Output structure", config={"backgroundColor": "black"})
242240

243-
gradio_app.unload(delete_cache)
241+
pred_button.click(fold, inputs=entities, outputs=output_mol)
242+
clear_button.click(lambda: ([], None), None, outputs=[entities, output_mol])
244243

244+
gradio_app.unload(delete_cache)
245+
gradio_app.launch()
245246

246247
# cli
247248
@click.command()
@@ -271,4 +272,4 @@ def app(checkpoint: str, cache_dir: str, precision: str):
271272
# dtype = torch.float32
272273
# model.to(device, dtype=dtype)
273274

274-
gradio_app.launch()
275+
start_gradio_app()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.31"
3+
version = "0.5.32"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)