Commit 9f85829

Add torch.compile tests (#1648)
* Add torch.compile tests
* Tests: work around aarch64 CPU regressions for torch 2.6.0; add Windows torch==2.7.0+cu118 test config
* Tests: skip torch.compile for CUDA on Windows
1 parent 503d243 commit 9f85829

File tree: 5 files changed, +188 −6 lines

* .github/workflows/tests.yml
* bitsandbytes/functional.py
* bitsandbytes/nn/modules.py
* tests/test_linear4bit.py
* tests/test_linear8bitlt.py

.github/workflows/tests.yml

Lines changed: 28 additions & 2 deletions
@@ -137,6 +137,10 @@ jobs:
       with:
         python-version: 3.9

+    - name: Setup MSVC
+      if: startsWith(matrix.os, 'windows')
+      uses: ilammy/[email protected] # to use cl for torch.compile
+
     - name: Install dependencies
       run: |
         pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
@@ -201,18 +205,40 @@ jobs:
            torch_version: "2.7.0"
            pypi_index: "https://download.pytorch.org/whl/cu128"

-          # L40S runners
+
+          # Linux L40S runners
           - os: ubuntu-22.04
             gpu: L40S
             runner: bandb-aws-g6e-4xlarge-plus-use1-public-80

-          # T4 runners
+          # Linux T4 runners
           - os: ubuntu-22.04
             gpu: T4
             runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80
+
+          # Specific Windows runners using cu118
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.2.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
           - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.6.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+          - os: windows-2025
+            arch: x86_64
             gpu: T4
             runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.7.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+
         exclude:
           # Our current T4 Windows runner has a driver too old (471.11)
           # and cannot support CUDA 12+. Skip for now.
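
Note: the new Setup MSVC step exists because, per the inline comment, torch.compile needs cl on Windows: TorchInductor's CPU backend compiles its generated C++ with the host compiler, and on Windows that is MSVC's cl.exe. A minimal sketch (not part of this commit) of the kind of CPU compile the workflow now exercises:

import torch

# Minimal sketch, not part of this commit: torch.compile on CPU lowers to C++
# via TorchInductor, which on Windows needs MSVC's cl.exe on PATH -- hence the
# "Setup MSVC" step above.
compiled_matmul = torch.compile(lambda a, b: a @ b + 1.0)

x = torch.randn(8, 8)
print(compiled_matmul(x, x).shape)  # torch.Size([8, 8])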

bitsandbytes/functional.py

Lines changed: 2 additions & 2 deletions
@@ -771,14 +771,14 @@ def quantize_blockwise(
         qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
         quant_state = QuantState(
             absmax=qabsmax,
-            code=code,
+            code=code.to(A.device, copy=True),
             blocksize=blocksize,
             dtype=A.dtype,
             offset=offset,
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device, copy=True), blocksize=blocksize, dtype=A.dtype)

     # TODO(matthewdouglas): Deprecate out kwarg
     out = out.copy_(_out) if out is not None else _out
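
The copy=True change matters because Tensor.to returns the original tensor unchanged when it is already on the requested device, so without it the QuantState can end up aliasing the shared codebook tensor; presumably the copy gives each QuantState its own buffer, which is friendlier to torch.compile. A generic illustration of that aliasing behavior (plain PyTorch, not bitsandbytes internals):

import torch

code = torch.linspace(-1.0, 1.0, 256)  # stand-in for a quantization codebook

aliased = code.to(code.device)           # already on the target device: returns the same object
fresh = code.to(code.device, copy=True)  # always materializes an independent tensor

print(aliased is code)  # True
print(fresh is code)    # False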

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
@@ -493,7 +493,7 @@ def forward(self, x: torch.Tensor):

         bias = None if self.bias is None else self.bias.to(self.compute_dtype)

-        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)


 class LinearFP4(Linear4bit):
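
Dropping .data here swaps the legacy Tensor.data view for the parameter itself; .data silently detaches from autograd and is generally discouraged under torch.compile tracing. A generic illustration of the difference (not the bitsandbytes code path):

import torch

w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4)

y_data = x @ w.data.t()  # .data bypasses autograd: the result has no path back to w
y_param = x @ w.t()      # gradients can flow back to w

print(y_data.requires_grad, y_param.requires_grad)  # False True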

tests/test_linear4bit.py

Lines changed: 91 additions & 1 deletion
@@ -1,13 +1,21 @@
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory

 import pytest
 import torch

 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import (
+    TRUE_FALSE,
+    describe_dtype,
+    get_available_devices,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)

 storage = {
     "uint8": torch.uint8,
@@ -275,3 +283,85 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
     # there was a bug where deepcopy would modify the original object
     assert dict_keys_before == dict_keys_after
     assert dict_keys_before == dict_keys_deserialized
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
+@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
+@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("FP4 is not supported for CPU")
+
+    if fullgraph and torch.__version__ < (2, 8):
+        pytest.skip("fullgraph mode requires torch 2.8 or higher")
+
+    if device == "cuda" and platform.system() == "Windows":
+        pytest.skip("Triton is not officially supported on Windows")
+
+    # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
+    if (
+        not fullgraph
+        and device == "cpu"
+        and platform.machine() == "aarch64"
+        and platform.system() == "Linux"
+        and ((2, 7) > torch.__version__ >= (2, 6))
+    ):
+        pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")
+
+    dim = 256
+    batch_size = 16
+
+    torch.compiler.reset()
+
+    # Create a small network with Linear4bit layers
+    net = torch.nn.Sequential(
+        *[
+            bnb.nn.Linear4bit(
+                dim,
+                dim,
+                bias=bias,
+                compute_dtype=compute_dtype,
+                compress_statistics=compress_statistics,
+                quant_type=quant_type,
+            )
+            for _ in range(4)
+        ]
+    ).to(device)
+
+    # Create input tensor
+    x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device)
+
+    # Get reference output before compilation
+    with torch.no_grad():
+        ref_output = net(x)
+
+    # Compile the model
+    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+
+    # Get output from compiled model
+    with torch.no_grad():
+        compiled_output = compiled_net(x)
+
+    # Check outputs match
+    assert compiled_output.shape == ref_output.shape
+    assert compiled_output.device == ref_output.device
+    assert compiled_output.dtype == ref_output.dtype
+    torch.testing.assert_close(compiled_output, ref_output)
+
+    # Test with gradients
+    x.requires_grad_(True)
+    y1 = net(x).sum()
+    y1.backward()
+    grad_ref = x.grad.clone()
+
+    x.grad = None
+    y2 = compiled_net(x).sum()
+    y2.backward()
+    grad_compiled = x.grad.clone()
+
+    torch.testing.assert_close(grad_compiled, grad_ref)
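
The version guards in this test compare torch.__version__ directly against tuples; this works because PyTorch exposes the version as a TorchVersion, a str subclass with version-aware comparisons. A small sketch of the comparisons the skipif/xfail conditions rely on:

import torch
from torch.torch_version import TorchVersion

v = TorchVersion("2.6.0")
print(v >= (2, 6))                 # True
print((2, 7) > v >= (2, 6))        # True -- the shape of the aarch64 regression guard above
print(torch.__version__ < (2, 4))  # the skipif condition used by the new tests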

tests/test_linear8bitlt.py

Lines changed: 66 additions & 0 deletions
@@ -2,6 +2,7 @@
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory

 import pytest
@@ -224,3 +225,68 @@ def test_linear8bit_serialization(linear8bit):
     # check for a bug where SCB and CB were not copied
     assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
     assert (linear8bit.weight.CB == deserialized.weight.CB).all()
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
+@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
+@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
+    if device == "cuda" and platform.system() == "Windows":
+        pytest.skip("Triton is not officially supported on Windows")
+
+    dim = 256
+    batch_size = 16
+
+    torch.compiler.reset()
+
+    # Create a small network with Linear8bitLt layers
+    net = torch.nn.Sequential(
+        *[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
+    ).to(device)
+
+    dynamic_output_shapes = fullgraph and threshold > 0
+    with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes):
+        # Create input tensor
+        x = torch.randn(batch_size, dim, dtype=torch.float16, device=device)
+
+        # Get reference output before compilation
+        with torch.no_grad():
+            ref_output = net(x)
+
+        # Compile the model
+        compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+
+        # Get output from compiled model
+        with torch.no_grad():
+            compiled_output = compiled_net(x)
+
+        # Check outputs match
+        assert compiled_output.shape == ref_output.shape
+        assert compiled_output.device == ref_output.device
+        assert compiled_output.dtype == ref_output.dtype
+        torch.testing.assert_close(compiled_output, ref_output)

+        # Test with gradients. Currently only works with threshold=0.
+        # Has a strange regression on Linux aarch64 CPU in torch==2.6.0.
+        is_broken_platform = (
+            device == "cpu"
+            and platform.machine() == "aarch64"
+            and platform.system() == "Linux"
+            and ((2, 7) > torch.__version__ >= (2, 6))
+        )
+
+        if threshold == 0 and not is_broken_platform:
+            x.requires_grad_(True)
+            y1 = net(x).sum()
+            y1.backward()
+            grad_ref = x.grad.clone()
+
+            x.grad = None
+            y2 = compiled_net(x).sum()
+            y2.backward()
+            grad_compiled = x.grad.clone()
+
+            torch.testing.assert_close(grad_compiled, grad_ref)
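
The capture_dynamic_output_shape_ops patch covers the fullgraph=True, threshold > 0 combination: with a nonzero threshold the LLM.int8() path extracts a data-dependent number of outlier columns, and Dynamo can only keep such an op inside a full graph when that flag is set. A minimal stand-in (outlier_columns below is hypothetical, not the bitsandbytes kernel) showing the kind of op the flag unlocks:

import torch

@torch.compile(fullgraph=True)
def outlier_columns(x, threshold):
    # Hypothetical stand-in: the number of selected columns is only known at
    # runtime, like the threshold > 0 outlier path in Linear8bitLt.
    return (x.abs().amax(dim=0) > threshold).nonzero()

x = torch.randn(16, 256)
with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True):
    print(outlier_columns(x, 6.0).shape)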
