Skip to content

Commit c6ba287

Browse files
committed
Enable even more unit tests for warp size 32
1 parent 3f9f6f3 commit c6ba287

File tree

3 files changed: 9 additions (+9) and 9 deletions (-9)

3 files changed

+9
-9
lines changed

tests/test_functional.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -491,7 +491,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
491491
@pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
492492
@pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
493493
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
494-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
494+
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
495495
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
496496
def min_max(x):
497497
maxA = torch.amax(x, dim=2, keepdim=True)
@@ -1205,7 +1205,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
12051205
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
12061206
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
12071207
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
1208-
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
1208+
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
12091209
def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
12101210
"""
12111211
Test that we can successfully quantize a large tensor. Note that the following limitations apply:
@@ -1428,7 +1428,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
14281428
@pytest.mark.parametrize("device", get_available_devices())
14291429
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
14301430
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
1431-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
1431+
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
14321432
def test_gemv_eye_4bit(self, device, storage_type, dtype):
14331433
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
14341434
pytest.skip("This configuration is not supported on HPU.")

tests/test_linear8bitlt.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import torch
1010

1111
import bitsandbytes as bnb
12-
from bitsandbytes.cextension import HIP_ENVIRONMENT
12+
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
1313
from bitsandbytes.nn.modules import Linear8bitLt
1414
from tests.helpers import (
1515
TRUE_FALSE,
@@ -234,7 +234,7 @@ def test_linear8bit_serialization(linear8bit):
234234
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
235235
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
236236
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
237-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
237+
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
238238
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
239239
if device == "cuda" and platform.system() == "Windows":
240240
pytest.skip("Triton is not officially supported on Windows")

tests/test_parametrize.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import torch.nn as nn
44

55
from bitsandbytes import functional as F
6-
from bitsandbytes.cextension import HIP_ENVIRONMENT
6+
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
77
from bitsandbytes.nn.parametrize import (
88
Bnb4bitParametrization,
99
replace_parameter_4bit,
@@ -39,7 +39,7 @@ def __init__(self, device="cpu", dtype=torch.float32):
3939
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
4040
@pytest.mark.parametrize(
4141
"blocksize",
42-
[64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
42+
[64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
4343
)
4444
def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize):
4545
"""Test basic parameter replacement with 4-bit quantization on different dtypes."""
@@ -267,7 +267,7 @@ def test_quant_state_preservation(device, dtype):
267267

268268
module = ParametrizeTestModule(device=device, dtype=dtype)
269269

270-
blocksize = 128 if HIP_ENVIRONMENT else 64
270+
blocksize = 128 if ROCM_WARP_SIZE_64 else 64
271271

272272
# Apply parametrization with specific settings
273273
replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize)
@@ -326,7 +326,7 @@ def test_multiple_parameters(device, dtype):
326326
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
327327
@pytest.mark.parametrize(
328328
"blocksize",
329-
[64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
329+
[64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
330330
)
331331
def test_different_blocksizes(device, dtype, blocksize):
332332
"""Test parametrization with different block sizes to verify flexibility."""

0 commit comments

Comments
 (0)