Skip to content

Commit fef62f7

Browse files
Tests: skip torch.compile for CUDA on Windows
1 parent 3e8c348 commit fef62f7

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

.github/workflows/tests.yml

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -205,27 +205,40 @@ jobs:
205205
torch_version: "2.7.0"
206206
pypi_index: "https://download.pytorch.org/whl/cu128"
207207

208-
# Add torch 2.7+cu118 for Windows.
209-
- os: windows-2025
210-
arch: x86_64
211-
gpu: T4
212-
runner: CUDA-Windows-x64
213-
cuda_version: "11.8.0"
214-
torch_version: "2.7.0"
215-
pypi_index: "https://download.pytorch.org/whl/cu118"
216208

217-
# L40S runners
209+
# Linux L40S runners
218210
- os: ubuntu-22.04
219211
gpu: L40S
220212
runner: bandb-aws-g6e-4xlarge-plus-use1-public-80
221213

222-
# T4 runners
214+
# Linux T4 runners
223215
- os: ubuntu-22.04
224216
gpu: T4
225217
runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80
218+
219+
# Specific Windows runners using cu118
220+
- os: windows-2025
221+
arch: x86_64
222+
gpu: T4
223+
runner: CUDA-Windows-x64
224+
cuda_version: "11.8.0"
225+
torch_version: "2.2.0"
226+
pypi_index: "https://download.pytorch.org/whl/cu118"
226227
- os: windows-2025
228+
arch: x86_64
227229
gpu: T4
228230
runner: CUDA-Windows-x64
231+
cuda_version: "11.8.0"
232+
torch_version: "2.6.0"
233+
pypi_index: "https://download.pytorch.org/whl/cu118"
234+
- os: windows-2025
235+
arch: x86_64
236+
gpu: T4
237+
runner: CUDA-Windows-x64
238+
cuda_version: "11.8.0"
239+
torch_version: "2.7.0"
240+
pypi_index: "https://download.pytorch.org/whl/cu118"
241+
229242
exclude:
230243
# Our current T4 Windows runner has a driver too old (471.11)
231244
# and cannot support CUDA 12+. Skip for now.

tests/test_linear4bit.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,9 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
300300
if fullgraph and torch.__version__ < (2, 8):
301301
pytest.skip("fullgraph mode requires torch 2.8 or higher")
302302

303+
if device == "cuda" and platform.system() == "Windows":
304+
pytest.skip("Triton is not officially supported on Windows")
305+
303306
# Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
304307
if (
305308
not fullgraph

tests/test_linear8bitlt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ def test_linear8bit_serialization(linear8bit):
234234
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
235235
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
236236
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
237+
if device == "cuda" and platform.system() == "Windows":
238+
pytest.skip("Triton is not officially supported on Windows")
239+
237240
dim = 256
238241
batch_size = 16
239242

0 commit comments

Comments
 (0)