Skip to content

Commit 6693b4b

Browse files
authored
feat: run thunder tests as part of LitGPT CI (#1975)
1 parent 7779471 commit 6693b4b

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

.azure/gpu-test.yml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,25 @@ jobs:
7575
displayName: "Env details"
7676
7777
- bash: |
78-
pytest -v --disable-pytest-warnings --strict-markers --color=yes \
78+
pytest -v \
7979
--ignore-glob="tests/test_thunder*" \
8080
--ignore="tests/test_unsloth_executor.py"
8181
displayName: "Ordinary tests"
8282
condition: ne(variables['dependency'], 'compiler')
8383
timeoutInMinutes: "5"
8484
8585
- bash: |
86-
pytest -v --disable-pytest-warnings --strict-markers --color=yes
86+
# install thunder from source, so that, thunder.tests will be available
87+
pip install -U "thunder[test] @ git+https://github.com/Lightning-AI/lightning-thunder.git"
88+
PL_RUN_CUDA_TESTS=0 pytest tests/ext_thunder/test_thunder_networks.py -v # without env var, it filters out all tests
89+
displayName: "Extra tests w. Thunder [main branch]"
90+
condition: eq(variables['dependency'], 'compiler')
91+
env:
92+
PL_RUN_CUDA_TESTS: "0"
93+
timeoutInMinutes: "10"
94+
95+
- bash: |
96+
pytest -v
8797
displayName: "All tests"
8898
condition: eq(variables['dependency'], 'compiler')
8999
timeoutInMinutes: "5"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ optional-dependencies.test = [
7575
"einops>=0.7",
7676
"protobuf>=4.23.4",
7777
"pytest>=8.1.1",
78+
"pytest-benchmark>=5.1",
7879
"pytest-dependency>=0.6",
7980
"pytest-rerunfailures>=14",
8081
"pytest-timeout>=2.3.1",
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Run thunder tests as part of LitGPT CI"""
2+
3+
from litgpt.utils import _THUNDER_AVAILABLE
4+
5+
if _THUNDER_AVAILABLE:
6+
from thunder.tests.test_networks import * # noqa: F403
7+
else:
8+
print("Skipping test_thunder_networks.py (thunder not available)")

0 commit comments

Comments
 (0)