@@ -70,178 +70,179 @@ def delete_cache(request: gr.Request):
7070 shutil .rmtree (str (user_dir ))
7171
7272
73- with gr .Blocks (delete_cache = (600 , 3600 )) as gradio_app :
74- entities = gr .State ([])
75-
76- with gr .Row ():
77- gr .Markdown ("### AlphaFold3 PyTorch Web UI" )
78-
79- with gr .Row ():
80- gr .Column (scale = 8 )
81- # upload_json_button = gr.Button("Upload JSON", scale=1, min_width=100)
82- clear_button = gr .Button ("Clear" , scale = 1 , min_width = 100 )
83-
84- with gr .Row ():
85- with gr .Column (scale = 1 , min_width = 150 ):
86- mtype = gr .Dropdown (
87- value = "Protein" ,
88- label = "Molecule type" ,
89- choices = ["Protein" , "DNA" , "RNA" , "Ligand" , "Ion" ],
90- interactive = True ,
91- )
92- with gr .Column (scale = 1 , min_width = 80 ):
93- c = gr .Number (
94- value = 1 ,
95- label = "Copies" ,
96- interactive = True ,
97- )
98-
99- with gr .Column (scale = 8 , min_width = 200 ):
100-
101- @gr .render (inputs = mtype )
102- def render_sequence (mol_type ):
103- if mol_type in ["Protein" , "DNA" , "RNA" ]:
104- seq = gr .Textbox (
105- label = "Paste sequence or fasta" ,
106- placeholder = "Input" ,
107- interactive = True ,
108- )
109- elif mol_type == "Ligand" :
110- seq = gr .Dropdown (
111- label = "Select ligand" ,
112- choices = [
113- "ADP - Adenosine disphosphate" ,
114- "ATP - Adenosine triphosphate" ,
115- "AMP - Adenosine monophosphate" ,
116- "GTP - Guanosine-5'-triphosphate" ,
117- "GDP - Guanosine-5'-diphosphate" ,
118- "FAD - Flavin adenine dinucleotide" ,
119- "NAD - Nicotinamide-adenine-dinucleotide" ,
120- "NAP - Nicotinamide-adenine-dinucleotide phosphate (NADP)" ,
121- "NDP - Dihydro-nicotinamide-adenine-dinucleotide-phosphate (NADPH)" ,
122- "HEM - Heme" ,
123- "HEC - Heme C" ,
124- "OLA - Oleic acid" ,
125- "MYR - Myristic acid" ,
126- "CIT - Citric acid" ,
127- "CLA - Chlorophyll A" ,
128- "CHL - Chlorophyll B" ,
129- "BCL - Bacteriochlorophyll A" ,
130- "BCB - Bacteriochlorophyll B" ,
131- ],
132- interactive = True ,
133- )
134- elif mol_type == "Ion" :
135- seq = gr .Dropdown (
136- label = "Select ion" ,
137- choices = [
138- "Mg²⁺" ,
139- "Zn²⁺" ,
140- "Cl⁻" ,
141- "Ca²⁺" ,
142- "Na⁺" ,
143- "Mn²⁺" ,
144- "K⁺" ,
145- "Fe³⁺" ,
146- "Cu²⁺" ,
147- "Co²⁺" ,
148- ],
149- interactive = True ,
150- )
151-
152- add_button .click (add_entity , inputs = [entities , mtype , c , seq ], outputs = [entities ])
153- clear_button .click (lambda : ("Protein" , 1 , None ), None , outputs = [mtype , c , seq ])
154-
155- add_button = gr .Button ("Add entity" , scale = 1 , min_width = 100 )
156-
157- def add_entity (entities , mtype = "Protein" , c = 1 , seq = "" ):
158- if seq is None or len (seq ) == 0 :
159- gr .Info ("Input required" )
160- return entities
161-
162- seq_norm = seq .strip (" \t \n \r " ).upper ()
163-
164- if mtype in ["Protein" , "DNA" , "RNA" ]:
165- if mtype == "Protein" and any ([x not in "ARDCQEGHILKMNFPSTWYV" for x in seq_norm ]):
166- gr .Info ("Invalid protein sequence. Allowed characters: A, R, D, C, Q, E, G, H, I, L, K, M, N, F, P, S, T, W, Y, V" )
167- return entities
168-
169- if mtype == "DNA" and any ([x not in "ACGT" for x in seq_norm ]):
170- gr .Info ("Invalid DNA sequence. Allowed characters: A, C, G, T" )
171- return entities
172-
173- if mtype == "RNA" and any ([x not in "ACGU" for x in seq_norm ]):
174- gr .Info ("Invalid RNA sequence. Allowed characters: A, C, G, U" )
175- return entities
176-
177- if len (seq ) < 4 :
178- gr .Info ("Minimum 4 characters required" )
179- return entities
180-
181- elif mtype == "Ligand" :
182- if seq is None or len (seq ) == 0 :
183- gr .Info ("Select a ligand" )
184- return entities
185- seq_norm = seq .split (" - " )[0 ]
186- elif mtype == "Ion" :
187- if seq is None or len (seq ) == 0 :
188- gr .Info ("Select an ion" )
189- return entities
190- seq_norm = "" .join ([x for x in seq if x .isalpha ()])
191-
192- new_entity = {"mol_type" : mtype , "num_copies" : c , "sequence" : seq_norm }
193-
194- return entities + [new_entity ]
195-
196- @gr .render (inputs = entities )
197- def render_entities (entity_list ):
198- for idx , entity in enumerate (entity_list ):
199- with gr .Row ():
200- gr .Text (
201- value = entity ["mol_type" ],
202- label = "Type" ,
203- scale = 1 ,
204- min_width = 90 ,
205- interactive = False ,
73+ def start_gradio_app ():
74+ with gr .Blocks (delete_cache = (600 , 3600 )) as gradio_app :
75+ entities = gr .State ([])
76+
77+ with gr .Row ():
78+ gr .Markdown ("### AlphaFold3 PyTorch Web UI" )
79+
80+ with gr .Row ():
81+ gr .Column (scale = 8 )
82+ # upload_json_button = gr.Button("Upload JSON", scale=1, min_width=100)
83+ clear_button = gr .Button ("Clear" , scale = 1 , min_width = 100 )
84+
85+ with gr .Row ():
86+ with gr .Column (scale = 1 , min_width = 150 ):
87+ mtype = gr .Dropdown (
88+ value = "Protein" ,
89+ label = "Molecule type" ,
90+ choices = ["Protein" , "DNA" , "RNA" , "Ligand" , "Ion" ],
91+ interactive = True ,
20692 )
207- gr .Text (
208- value = entity ["num_copies" ],
93+ with gr .Column (scale = 1 , min_width = 80 ):
94+ c = gr .Number (
95+ value = 1 ,
20996 label = "Copies" ,
210- scale = 1 ,
211- min_width = 80 ,
212- interactive = False ,
97+ interactive = True ,
21398 )
21499
215- sequence = entity ["sequence" ]
216- if entity ["mol_type" ] not in ["Ligand" , "Ion" ]:
217- # Split every 10 characters, and add a \t after each split
218- sequence = "\t " .join ([sequence [i : i + 10 ] for i in range (0 , len (sequence ), 10 )])
219-
220- gr .Text (
221- value = sequence ,
222- label = "Sequence" ,
223- placeholder = "Input" ,
224- scale = 7 ,
225- min_width = 200 ,
226- interactive = False ,
227- )
100+ with gr .Column (scale = 8 , min_width = 200 ):
101+
102+ @gr .render (inputs = mtype )
103+ def render_sequence (mol_type ):
104+ if mol_type in ["Protein" , "DNA" , "RNA" ]:
105+ seq = gr .Textbox (
106+ label = "Paste sequence or fasta" ,
107+ placeholder = "Input" ,
108+ interactive = True ,
109+ )
110+ elif mol_type == "Ligand" :
111+ seq = gr .Dropdown (
112+ label = "Select ligand" ,
113+ choices = [
114+ "ADP - Adenosine disphosphate" ,
115+ "ATP - Adenosine triphosphate" ,
116+ "AMP - Adenosine monophosphate" ,
117+ "GTP - Guanosine-5'-triphosphate" ,
118+ "GDP - Guanosine-5'-diphosphate" ,
119+ "FAD - Flavin adenine dinucleotide" ,
120+ "NAD - Nicotinamide-adenine-dinucleotide" ,
121+ "NAP - Nicotinamide-adenine-dinucleotide phosphate (NADP)" ,
122+ "NDP - Dihydro-nicotinamide-adenine-dinucleotide-phosphate (NADPH)" ,
123+ "HEM - Heme" ,
124+ "HEC - Heme C" ,
125+ "OLA - Oleic acid" ,
126+ "MYR - Myristic acid" ,
127+ "CIT - Citric acid" ,
128+ "CLA - Chlorophyll A" ,
129+ "CHL - Chlorophyll B" ,
130+ "BCL - Bacteriochlorophyll A" ,
131+ "BCB - Bacteriochlorophyll B" ,
132+ ],
133+ interactive = True ,
134+ )
135+ elif mol_type == "Ion" :
136+ seq = gr .Dropdown (
137+ label = "Select ion" ,
138+ choices = [
139+ "Mg²⁺" ,
140+ "Zn²⁺" ,
141+ "Cl⁻" ,
142+ "Ca²⁺" ,
143+ "Na⁺" ,
144+ "Mn²⁺" ,
145+ "K⁺" ,
146+ "Fe³⁺" ,
147+ "Cu²⁺" ,
148+ "Co²⁺" ,
149+ ],
150+ interactive = True ,
151+ )
152+
153+ add_button .click (add_entity , inputs = [entities , mtype , c , seq ], outputs = [entities ])
154+ clear_button .click (lambda : ("Protein" , 1 , None ), None , outputs = [mtype , c , seq ])
155+
156+ add_button = gr .Button ("Add entity" , scale = 1 , min_width = 100 )
157+
158+ def add_entity (entities , mtype = "Protein" , c = 1 , seq = "" ):
159+ if seq is None or len (seq ) == 0 :
160+ gr .Info ("Input required" )
161+ return entities
162+
163+ seq_norm = seq .strip (" \t \n \r " ).upper ()
164+
165+ if mtype in ["Protein" , "DNA" , "RNA" ]:
166+ if mtype == "Protein" and any ([x not in "ARDCQEGHILKMNFPSTWYV" for x in seq_norm ]):
167+ gr .Info ("Invalid protein sequence. Allowed characters: A, R, D, C, Q, E, G, H, I, L, K, M, N, F, P, S, T, W, Y, V" )
168+ return entities
169+
170+ if mtype == "DNA" and any ([x not in "ACGT" for x in seq_norm ]):
171+ gr .Info ("Invalid DNA sequence. Allowed characters: A, C, G, T" )
172+ return entities
173+
174+ if mtype == "RNA" and any ([x not in "ACGU" for x in seq_norm ]):
175+ gr .Info ("Invalid RNA sequence. Allowed characters: A, C, G, U" )
176+ return entities
177+
178+ if len (seq ) < 4 :
179+ gr .Info ("Minimum 4 characters required" )
180+ return entities
181+
182+ elif mtype == "Ligand" :
183+ if seq is None or len (seq ) == 0 :
184+ gr .Info ("Select a ligand" )
185+ return entities
186+ seq_norm = seq .split (" - " )[0 ]
187+ elif mtype == "Ion" :
188+ if seq is None or len (seq ) == 0 :
189+ gr .Info ("Select an ion" )
190+ return entities
191+ seq_norm = "" .join ([x for x in seq if x .isalpha ()])
192+
193+ new_entity = {"mol_type" : mtype , "num_copies" : c , "sequence" : seq_norm }
194+
195+ return entities + [new_entity ]
196+
197+ @gr .render (inputs = entities )
198+ def render_entities (entity_list ):
199+ for idx , entity in enumerate (entity_list ):
200+ with gr .Row ():
201+ gr .Text (
202+ value = entity ["mol_type" ],
203+ label = "Type" ,
204+ scale = 1 ,
205+ min_width = 90 ,
206+ interactive = False ,
207+ )
208+ gr .Text (
209+ value = entity ["num_copies" ],
210+ label = "Copies" ,
211+ scale = 1 ,
212+ min_width = 80 ,
213+ interactive = False ,
214+ )
228215
229- del_button = gr .Button ("🗑️" , scale = 0 , min_width = 50 )
216+ sequence = entity ["sequence" ]
217+ if entity ["mol_type" ] not in ["Ligand" , "Ion" ]:
218+ # Split every 10 characters, and add a \t after each split
219+ sequence = "\t " .join ([sequence [i : i + 10 ] for i in range (0 , len (sequence ), 10 )])
220+
221+ gr .Text (
222+ value = sequence ,
223+ label = "Sequence" ,
224+ placeholder = "Input" ,
225+ scale = 7 ,
226+ min_width = 200 ,
227+ interactive = False ,
228+ )
230229
231- def delete (entity_id = idx ):
232- entity_list .pop (entity_id )
233- return entity_list
230+ del_button = gr .Button ("🗑️" , scale = 0 , min_width = 50 )
234231
235- del_button .click (delete , None , outputs = [entities ])
232+ def delete (entity_id = idx ):
233+ entity_list .pop (entity_id )
234+ return entity_list
236235
237- pred_button = gr .Button ("Predict" , scale = 1 , min_width = 100 )
238- output_mol = Molecule3D (label = "Output structure" , config = {"backgroundColor" : "black" })
236+ del_button .click (delete , None , outputs = [entities ])
239237
240- pred_button . click ( fold , inputs = entities , outputs = output_mol )
241- clear_button . click ( lambda : ([], None ), None , outputs = [ entities , output_mol ] )
238+ pred_button = gr . Button ( "Predict" , scale = 1 , min_width = 100 )
239+ output_mol = Molecule3D ( label = "Output structure" , config = { "backgroundColor" : "black" } )
242240
243- gradio_app .unload (delete_cache )
241+ pred_button .click (fold , inputs = entities , outputs = output_mol )
242+ clear_button .click (lambda : ([], None ), None , outputs = [entities , output_mol ])
244243
244+ gradio_app .unload (delete_cache )
245+ gradio_app .launch ()
245246
246247# cli
247248@click .command ()
@@ -271,4 +272,4 @@ def app(checkpoint: str, cache_dir: str, precision: str):
271272 # dtype = torch.float32
272273 # model.to(device, dtype=dtype)
273274
274- gradio_app . launch ()
275+ start_gradio_app ()
0 commit comments