
Commit 8f5d139

Enable more tests on AMD for warp size 32 (#1805)
* Enable even more unit tests for warp size 32
* Revert comment changes from previous PR for consistency
1 parent 3f9f6f3 commit 8f5d139

4 files changed, +11 -11 lines changed


csrc/kernels.hip

Lines changed: 2 additions & 2 deletions
@@ -2613,9 +2613,9 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 {

     // per threadblock:
-    // load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps]
+    // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
     // 4 warps -> 4 loads per iter
-    // 1 x BNB_WARP_SIZE * BNB_WARP_SIZE x 4 -> 1x4 outputs per thread block
+    // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
     typedef hipcub::WarpReduce<float, BNB_WARP_SIZE> WarpReduce;
     __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE];
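For context, the hunk above only touches comments: it describes the per-threadblock load pattern in terms of the generic warp size rather than spelling out `BNB_WARP_SIZE`, since the wavefront width differs across AMD GPUs (64 on CDNA, 32 on RDNA) while NVIDIA warps are always 32. A minimal sketch for checking what a given device reports, assuming a recent PyTorch build that exposes `warp_size` on device properties:

```python
import torch

# Minimal sketch: print the warp/wavefront width of device 0.
# Assumes a recent PyTorch whose device properties expose `warp_size`.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: warp size {props.warp_size}")  # 32 on NVIDIA/RDNA, 64 on CDNA
```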

tests/test_functional.py

Lines changed: 3 additions & 3 deletions
@@ -491,7 +491,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
     @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
     @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
     @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
-    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
+    @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
     def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
         def min_max(x):
             maxA = torch.amax(x, dim=2, keepdim=True)
@@ -1205,7 +1205,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
     @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
     @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
-    @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
+    @pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
     def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
         """
         Test that we can successfully quantize a large tensor. Note that the following limitations apply:
@@ -1428,7 +1428,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
+    @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
     def test_gemv_eye_4bit(self, device, storage_type, dtype):
         if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
             pytest.skip("This configuration is not supported on HPU.")
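The pattern in all three hunks is the same: skips and blocksize restrictions that previously applied to every ROCm build (`HIP_ENVIRONMENT`) now apply only when the GPU uses 64-wide wavefronts. A rough reconstruction of how such a flag could be derived (hypothetical; the actual definition lives in `bitsandbytes.cextension`):

```python
import torch

# Hypothetical sketch, not the actual bitsandbytes code: the flag should be
# True only on ROCm builds whose GPU runs 64-wide wavefronts (CDNA-class).
HIP_ENVIRONMENT = torch.version.hip is not None
ROCM_WARP_SIZE_64 = (
    HIP_ENVIRONMENT
    and torch.cuda.is_available()
    and torch.cuda.get_device_properties(0).warp_size == 64
)
```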

tests/test_linear8bitlt.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 import torch

 import bitsandbytes as bnb
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.cextension import ROCM_WARP_SIZE_64
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import (
     TRUE_FALSE,
@@ -234,7 +234,7 @@ def test_linear8bit_serialization(linear8bit):
 @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
+@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
 def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
     if device == "cuda" and platform.system() == "Windows":
         pytest.skip("Triton is not officially supported on Windows")
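The effect of the narrowed marker: a test decorated this way still runs on CUDA and on wave32 ROCm devices, and is skipped only on wave64 hardware. An illustrative usage (not from this diff):

```python
import pytest
from bitsandbytes.cextension import ROCM_WARP_SIZE_64

@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
def test_runs_everywhere_except_wave64():
    # Executed on CUDA and wave32 ROCm; skipped only on wave64 ROCm.
    assert True
```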

tests/test_parametrize.py

Lines changed: 4 additions & 4 deletions
@@ -3,7 +3,7 @@
 import torch.nn as nn

 from bitsandbytes import functional as F
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.cextension import ROCM_WARP_SIZE_64
 from bitsandbytes.nn.parametrize import (
     Bnb4bitParametrization,
     replace_parameter_4bit,
@@ -39,7 +39,7 @@ def __init__(self, device="cpu", dtype=torch.float32):
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize(
     "blocksize",
-    [64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
+    [64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
 )
 def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize):
     """Test basic parameter replacement with 4-bit quantization on different dtypes."""
@@ -267,7 +267,7 @@ def test_quant_state_preservation(device, dtype):

     module = ParametrizeTestModule(device=device, dtype=dtype)

-    blocksize = 128 if HIP_ENVIRONMENT else 64
+    blocksize = 128 if ROCM_WARP_SIZE_64 else 64

     # Apply parametrization with specific settings
     replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize)
@@ -326,7 +326,7 @@ def test_multiple_parameters(device, dtype):
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize(
     "blocksize",
-    [64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
+    [64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
 )
 def test_different_blocksizes(device, dtype, blocksize):
     """Test parametrization with different block sizes to verify flexibility."""
