Skip to content

Commit 9c49d09

Browse files
Tests: work around aarch64 CPU regressions in torch 2.6.0; add Windows torch==2.7.0+cu118 test config
1 parent f176eab commit 9c49d09

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

.github/workflows/tests.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ jobs:
129129
with:
130130
python-version: 3.9
131131

132+
- name: Setup MSVC
133+
if: startsWith(matrix.os, 'windows')
134+
uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl for torch.compile
135+
132136
- name: Install dependencies
133137
run: |
134138
pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
@@ -188,6 +192,15 @@ jobs:
188192
torch_version: "2.7.0"
189193
pypi_index: "https://download.pytorch.org/whl/cu128"
190194

195+
# Add torch 2.7+cu118 for Windows.
196+
- os: windows-2025
197+
arch: x86_64
198+
gpu: T4
199+
runner: CUDA-Windows-x64
200+
cuda_version: "11.8.0"
201+
torch_version: "2.7.0"
202+
pypi_index: "https://download.pytorch.org/whl/cu118"
203+
191204
# L40S runners
192205
- os: ubuntu-22.04
193206
gpu: L40S

tests/test_linear4bit.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import os
33
import pickle
4+
import platform
45
from tempfile import TemporaryDirectory
56

67
import pytest
@@ -299,6 +300,16 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
299300
if fullgraph and torch.__version__ < (2, 8):
300301
pytest.skip("fullgraph mode requires torch 2.8 or higher")
301302

303+
# torch 2.6.x has an unexplained regression on Linux aarch64 CPU when fullgraph=False, so mark this case as an expected failure.
304+
if (
305+
not fullgraph
306+
and device == "cpu"
307+
and platform.machine() == "aarch64"
308+
and platform.system() == "Linux"
309+
and ((2, 7) > torch.__version__ >= (2, 6))
310+
):
311+
pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")
312+
302313
dim = 256
303314
batch_size = 16
304315

tests/test_linear8bitlt.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import copy
33
import os
44
import pickle
5+
import platform
56
from tempfile import TemporaryDirectory
67

78
import pytest
@@ -238,7 +239,6 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
238239

239240
torch.compiler.reset()
240241

241-
torch._dynamo.config.patch()
242242
# Create a small network with Linear8bitLt layers
243243
net = torch.nn.Sequential(
244244
*[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
@@ -267,7 +267,15 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
267267
torch.testing.assert_close(compiled_output, ref_output)
268268

269269
# Test with gradients. Currently only works with threshold=0.
270-
if threshold == 0:
270+
# The gradient check is skipped on Linux aarch64 CPU with torch 2.6.x, which has an unexplained regression there.
271+
is_broken_platform = (
272+
device == "cpu"
273+
and platform.machine() == "aarch64"
274+
and platform.system() == "Linux"
275+
and ((2, 7) > torch.__version__ >= (2, 6))
276+
)
277+
278+
if threshold == 0 and not is_broken_platform:
271279
x.requires_grad_(True)
272280
y1 = net(x).sum()
273281
y1.backward()

0 commit comments

Comments
 (0)