Skip to content

Commit e83b610

Browse files
authored
Merge pull request #5 from RosettaCommons/add_workflow_for_tests
Add workflow for tests, had to add some code changes so tests can run on CPUs of GitHub runners.
2 parents 5ecca8d + 5e650cc commit e83b610

File tree

9 files changed

+446
-6
lines changed

9 files changed

+446
-6
lines changed

.github/workflows/test.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: Run Tests
2+
3+
on:
4+
push:
5+
branches: ["main"]
6+
pull_request:
7+
branches: ["main"]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- uses: actions/checkout@v4
15+
16+
# Set up micromamba
17+
- name: Set up micromamba
18+
uses: mamba-org/setup-micromamba@v2
19+
with:
20+
environment-file: rf_diffusion/environment/ci_environment.yml
21+
init-shell: bash
22+
cache-environment: true
23+
24+
- name: Install pytest
25+
shell: micromamba-shell {0}
26+
run: |
27+
python -m pip install pytest
28+
29+
- name: Download weights
30+
run: |
31+
mkdir weights
32+
curl -o weights/train_session2024-07-08_1720455712_BFF_3.00.pt https://files.ipd.uw.edu/pub/2025_RFDpoly/train_session2024-07-08_1720455712_BFF_3.00.pt
33+
34+
- name: Run tests
35+
shell: micromamba-shell {0}
36+
run: |
37+
python -m pytest test/test_demo.py

rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import torch
3030
import torch.nn.functional as F
3131
from torch import Tensor
32-
from torch.cuda.nvtx import range as nvtx_range
32+
from se3_transformer.utils.nvtx import nvtx_range
3333

3434
from se3_transformer.runtime.utils import degree_to_dim
3535

rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
3535
from se3_transformer.model.layers.linear import LinearSE3
3636
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
37-
from torch.cuda.nvtx import range as nvtx_range
37+
from se3_transformer.utils.nvtx import nvtx_range
3838

3939

4040
class AttentionSE3(nn.Module):

rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import torch.nn as nn
3232
from dgl import DGLGraph
3333
from torch import Tensor
34-
from torch.cuda.nvtx import range as nvtx_range
34+
from se3_transformer.utils.nvtx import nvtx_range
3535

3636
from se3_transformer.model.fiber import Fiber
3737
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features

rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import torch
2828
import torch.nn as nn
2929
from torch import Tensor
30-
from torch.cuda.nvtx import range as nvtx_range
30+
from se3_transformer.utils.nvtx import nvtx_range
3131

3232
from se3_transformer.model.fiber import Fiber
3333

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# se3_transformer/utils/nvtx.py
2+
from __future__ import annotations
3+
4+
from contextlib import contextmanager
5+
from typing import Iterator
6+
7+
@contextmanager
8+
def nvtx_range(message: str) -> Iterator[None]:
9+
"""
10+
Safe NVTX range context manager.
11+
12+
- If running with CUDA + NVTX support, emits real NVTX ranges.
13+
- Otherwise, becomes a no-op (CPU-only CI, ROCm-only builds, etc).
14+
"""
15+
try:
16+
import torch
17+
18+
if torch.cuda.is_available() and hasattr(torch.cuda, "nvtx"):
19+
try:
20+
from torch.cuda.nvtx import range as _nvtx_range
21+
with _nvtx_range(message):
22+
yield
23+
return
24+
except Exception:
25+
# CUDA available but NVTX missing/misconfigured -> fall back to no-op
26+
pass
27+
28+
yield
29+
except Exception:
30+
# torch not importable or other unexpected env issue -> no-op
31+
yield
32+
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
name: RFDpoly_env_ci_test
2+
channels:
3+
- pytorch
4+
- pyg
5+
- dglteam
6+
- conda-forge
7+
- bioconda
8+
- defaults
9+
dependencies:
10+
- python=3.10.8
11+
- pip=23.0.1
12+
- pytorch=1.13.1
13+
- cpuonly
14+
- pyg=2.2.0
15+
- pytorch-scatter=2.1.0
16+
- pytorch-sparse=0.6.16
17+
- pytorch-cluster=1.6.0
18+
- dgl=1.0.1
19+
- _libgcc_mutex=0.1
20+
- _openmp_mutex=4.5
21+
- anyio=3.5.0
22+
- appdirs=1.4.4
23+
- argon2-cffi=21.3.0
24+
- argon2-cffi-bindings=21.2.0
25+
- asttokens=2.0.5
26+
- attrs=22.1.0
27+
- babel=2.11.0
28+
- backcall=0.2.0
29+
- beautifulsoup4=4.11.1
30+
- blas=1.0
31+
- bleach=4.1.0
32+
- bottleneck=1.3.5
33+
- brotli=1.0.9
34+
- brotli-bin=1.0.9
35+
- brotlipy=0.7.0
36+
- bzip2=1.0.8
37+
- ca-certificates=2022.12.7
38+
- cairo=1.16.0
39+
- certifi=2022.12.7
40+
- cffi=1.15.1
41+
- charset-normalizer=2.0.4
42+
- comm=0.1.2
43+
- conda=23.1.0
44+
- conda-content-trust=0.1.3
45+
- conda-package-handling=2.0.2
46+
- conda-package-streaming=0.7.0
47+
- contourpy=1.0.5
48+
- cryptography=39.0.1
49+
- cycler=0.11.0
50+
- dbus=1.13.18
51+
- debugpy=1.5.1
52+
- decorator=5.1.1
53+
- defusedxml=0.7.1
54+
- entrypoints=0.4
55+
- executing=0.8.3
56+
- expat=2.4.9
57+
- flit-core=3.6.0
58+
- fontconfig=2.14.1
59+
- fonttools=4.25.0
60+
- freetype=2.12.1
61+
- giflib=5.2.1
62+
- glib=2.69.1
63+
- gst-plugins-base=1.14.1
64+
- gstreamer=1.14.1
65+
- icu=58.2
66+
- idna=3.4
67+
- intel-openmp=2021.4.0
68+
- ipykernel=6.19.2
69+
- ipython=8.10.0
70+
- ipython_genutils=0.2.0
71+
- jedi=0.18.1
72+
- jinja2=3.1.2
73+
- joblib=1.1.1
74+
- jpeg=9e
75+
- json5=0.9.6
76+
- jsonschema=4.17.3
77+
- jupyter_client=7.4.9
78+
- jupyter_core=5.2.0
79+
- jupyter_server=1.23.4
80+
- jupyterlab=3.5.3
81+
- jupyterlab_pygments=0.1.2
82+
- jupyterlab_server=2.19.0
83+
- kiwisolver=1.4.4
84+
- krb5=1.19.4
85+
- lcms2=2.12
86+
- ld_impl_linux-64=2.38
87+
- lerc=3.0
88+
- libbrotlicommon=1.0.9
89+
- libbrotlidec=1.0.9
90+
- libbrotlienc=1.0.9
91+
- libclang=10.0.1
92+
- libdeflate=1.17
93+
- libedit=3.1.20221030
94+
- libevent=2.1.12
95+
- libffi=3.4.2
96+
- libgcc-ng=12.2.0
97+
- libgfortran-ng=11.2.0
98+
- libgfortran5=11.2.0
99+
- libiconv=1.17
100+
- libllvm10=10.0.1
101+
- libpng=1.6.39
102+
- libpq=12.9
103+
- libsodium=1.0.18
104+
- libstdcxx-ng=11.2.0
105+
- libtiff=4.5.0
106+
- libuuid=1.41.5
107+
- libwebp=1.2.4
108+
- libwebp-base=1.2.4
109+
- libxcb=1.15
110+
- libxkbcommon=1.0.1
111+
- libxml2=2.9.14
112+
- libxslt=1.1.35
113+
- libzlib=1.2.13
114+
- llvm-openmp=15.0.7
115+
- lxml=4.9.1
116+
- lz4-c=1.9.4
117+
- markupsafe=2.1.1
118+
- matplotlib=3.7.0
119+
- matplotlib-base=3.7.0
120+
- matplotlib-inline=0.1.6
121+
- mistune=0.8.4
122+
- mkl=2021.4.0
123+
- mkl-service=2.4.0
124+
- mkl_fft=1.3.1
125+
- mkl_random=1.2.2
126+
- munkres=1.1.4
127+
- nbclassic=0.5.2
128+
- nbclient=0.5.13
129+
- nbconvert=6.5.4
130+
- nbformat=5.7.0
131+
- ncurses=6.4
132+
- nest-asyncio=1.5.6
133+
- networkx=2.8.4
134+
- notebook=6.5.2
135+
- notebook-shim=0.2.2
136+
- nspr=4.33
137+
- nss=3.74
138+
- numexpr=2.8.4
139+
- numpy=1.23.5
140+
- numpy-base=1.23.5
141+
- openbabel=3.1.1
142+
- openssl=1.1.1t
143+
- packaging=22.0
144+
- pandas=1.5.3
145+
- pandocfilters=1.5.0
146+
- parso=0.8.3
147+
- pcre=8.45
148+
- pexpect=4.8.0
149+
- pickleshare=0.7.5
150+
- pillow=9.4.0
151+
- pip=23.0.1
152+
- pixman=0.40.0
153+
- platformdirs=2.5.2
154+
- pluggy=1.0.0
155+
- ply=3.11
156+
- pooch=1.4.0
157+
- prometheus_client=0.14.1
158+
- prompt-toolkit=3.0.36
159+
- psutil=5.9.0
160+
- ptyprocess=0.7.0
161+
- pure_eval=0.2.2
162+
- pycosat=0.6.4
163+
- pycparser=2.21
164+
- pyg=2.2.0
165+
- pygments=2.11.2
166+
- pyopenssl=23.0.0
167+
- pyparsing=3.0.9
168+
- pyqt=5.15.7
169+
- pyrsistent=0.18.0
170+
- pysocks=1.7.1
171+
- python-dateutil=2.8.2
172+
- python-fastjsonschema=2.16.2
173+
- python_abi=3.10
174+
- pytz=2022.7
175+
- pyzmq=23.2.0
176+
- qt-main=5.15.2
177+
- qt-webengine=5.15.9
178+
- qtwebkit=5.212
179+
- readline=8.2
180+
- requests=2.28.1
181+
- ruamel.yaml=0.17.21
182+
- ruamel.yaml.clib=0.2.6
183+
- scikit-learn=1.2.1
184+
- scipy=1.10.0
185+
- seaborn=0.12.2
186+
- send2trash=1.8.0
187+
- setuptools=65.5.0
188+
- sip=6.6.2
189+
- six=1.16.0
190+
- sniffio=1.2.0
191+
- soupsieve=2.3.2.post1
192+
- sqlite=3.40.1
193+
- stack_data=0.2.0
194+
- terminado=0.17.1
195+
- threadpoolctl=2.2.0
196+
- tinycss2=1.2.1
197+
- tk=8.6.12
198+
- toml=0.10.2
199+
- tomli=2.0.1
200+
- toolz=0.12.0
201+
- tornado=6.2
202+
- tqdm=4.64.1
203+
- traitlets=5.7.1
204+
- typing-extensions=4.4.0
205+
- typing_extensions=4.4.0
206+
- tzdata=2022g
207+
- urllib3=1.26.14
208+
- wcwidth=0.2.5
209+
- webencodings=0.5.1
210+
- websocket-client=0.58.0
211+
- wheel=0.37.1
212+
- xz=5.2.10
213+
- zeromq=4.3.4
214+
- zlib=1.2.13
215+
- zstandard=0.19.0
216+
- zstd=1.5.2
217+
- pip:
218+
- antlr4-python3-runtime==4.9.3
219+
- assertpy==1.1
220+
- click==8.1.3
221+
- colorama==0.4.6
222+
- deepdiff==6.2.3
223+
- docker-pycreds==0.4.0
224+
- e3nn==0.5.1
225+
- gitdb==4.0.10
226+
- gitpython==3.1.31
227+
- hydra-core==1.3.2
228+
- mpmath==1.3.0
229+
- omegaconf==2.3.0
230+
- opt-einsum==3.3.0
231+
- opt-einsum-fx==0.1.4
232+
- ordered-set==4.1.0
233+
- orjson==3.8.7
234+
- pathtools==0.1.2
235+
- protobuf==4.22.1
236+
- pyqt5-sip==12.11.0
237+
- pyyaml==6.0
238+
- sentry-sdk==1.16.0
239+
- setproctitle==1.3.2
240+
- smmap==5.0.0
241+
- sympy==1.11.1
242+
- wandb==0.13.11

rf_diffusion/environment/macos_environment.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: RFDpoly_env
1+
name: RFDpoly_env_macos
22
channels:
33
- pytorch
44
- conda-forge
@@ -18,4 +18,4 @@ dependencies:
1818
- pip:
1919
- dgl==1.0.1
2020
- e3nn==0.5.1
21-
- hydra-core==1.3.2
21+
- hydra-core==1.3.2

0 commit comments

Comments
 (0)