Skip to content

Commit d158479

Browse files
committed
prepare for maybe nim speedup with fallback
1 parent 403921c commit d158479

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

alphafold3_pytorch/app.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,19 @@
33

44
import secrets
55
import shutil
6-
import gradio as gr
7-
from gradio_molecule3d import Molecule3D
86
from Bio.PDB import PDBIO
97

108
from alphafold3_pytorch import Alphafold3, Alphafold3Input
119

1210
# constants
11+
1312
model = None
1413
cache_path = None
1514
pdb_writer = PDBIO()
1615

17-
1816
# main fold function
19-
def fold(entities, request: gr.Request):
17+
18+
def fold(entities, request):
2019
proteins = []
2120
rnas = []
2221
dnas = []
@@ -59,9 +58,9 @@ def fold(entities, request: gr.Request):
5958

6059
return str(output_path)
6160

62-
6361
# gradio
64-
def delete_cache(request: gr.Request):
62+
63+
def delete_cache(request):
6564
if not request.session_hash:
6665
return
6766

@@ -71,6 +70,9 @@ def delete_cache(request: gr.Request):
7170

7271

7372
def start_gradio_app():
73+
import gradio as gr
74+
from gradio_molecule3d import Molecule3D
75+
7476
with gr.Blocks(delete_cache=(600, 3600)) as gradio_app:
7577
entities = gr.State([])
7678

alphafold3_pytorch/tensor_typing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from __future__ import annotations
2+
3+
import sh
24
from functools import partial
35
import importlib.metadata
6+
from packaging import version
7+
48
import torch
59
import numpy as np
610

@@ -87,6 +91,18 @@ def package_available(package_name: str) -> bool:
8791
else:
8892
checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant = False)
8993

94+
# whether to use Nim or not, depending on if available and version is adequate
95+
96+
try:
97+
sh.which('nim')
98+
HAS_NIM = True
99+
NIM_VERSION = sh.nim(eval = 'echo NimVersion', hints = 'off')
100+
except sh.ErrorReturnCode_1:
101+
HAS_NIM = False
102+
NIM_VERSION = None
103+
104+
assert not HAS_NIM or version.parse(NIM_VERSION) >= version.parse('2.0.8'), 'nim version must be 2.0.8 or above'
105+
90106
# check is github ci
91107

92108
IS_GITHUB_CI = env.bool('IS_GITHUB_CI', False)

0 commit comments

Comments
 (0)