Commit ba51d95

Added more extensive gemv tests; blocksize guard for gemv.

1 parent b8da4a1

File tree

5 files changed: +122 -69 lines changed


bitsandbytes/autograd/_functions.py

Lines changed: 7 additions & 1 deletion
@@ -3,6 +3,7 @@
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
 from typing import Tuple, Optional, List
+from warnings import warn

 import torch

@@ -565,6 +566,11 @@ def matmul(
 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        return F.gemv_4bit(A, B.t(), out, state=quant_state)
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
+        if A.shape[-1] % blocksize != 0:
+            warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+            return MatMul4Bit.apply(A, B, out, bias, quant_state)
+        else:
+            return F.gemv_4bit(A, B.t(), out, state=quant_state)
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
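
For context, a minimal usage sketch of the new guard (not part of the commit; the 4097 hidden size and the default 4-bit blocksize of 64 are illustrative assumptions). A hidden dimension that is not a multiple of the blocksize now emits the warning above and falls back to the MatMul4Bit path instead of calling the gemv kernel:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Hypothetical shapes: 4097 is not a multiple of the assumed default blocksize of 64,
# so matmul_4bit warns and dispatches to the slower MatMul4Bit autograd path.
A = torch.randn(1, 4097, dtype=torch.float16, device='cuda')
B = torch.randn(4096, 4097, dtype=torch.float16, device='cuda')
qB, state = F.quantize_4bit(B, quant_type='nf4')
out = bnb.matmul_4bit(A, qB.t(), state)  # warns, uses MatMul4Bit.apply

# With a hidden dimension that is a multiple of the blocksize (and A not requiring
# gradients), the fast F.gemv_4bit inference kernel is used instead.
A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
qB, state = F.quantize_4bit(B, quant_type='nf4')
out = bnb.matmul_4bit(A, qB.t(), state)  # uses F.gemv_4bit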

bitsandbytes/functional.py

Lines changed: 1 addition & 0 deletions
@@ -1504,6 +1504,7 @@ def gemv_4bit(
             lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
+
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

csrc/kernels.cu

Lines changed: 7 additions & 4 deletions
@@ -222,6 +222,7 @@ __device__ half dhDequantizeNF4(unsigned char val)

 __device__ float dDequantizeNF4(unsigned char val)
 {
+
   // the values for this tree was generated by test_normal_map_tree
   // in the file tests/test_functional.py
   if((val & 0b1000) == 8)
@@ -3526,10 +3527,9 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 {

   // per threadblock:
-  // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
-  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
+  // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
   // 4 warps -> 4 loads per iter
-  // 1x128 * 128x4 -> 1x4 outputs
+  // 1x32 * 32x4 -> 1x4 outputs per thread block
   typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];

@@ -3547,7 +3547,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

   for(int i = threadIdx.x; i < 16; i++)
     quant_map[i] = T(datatype[i]);
-
   __syncthreads();

   // A: [1, K]
@@ -3563,6 +3562,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
     {
       if((inner_idx_halved + num_values_8bit) < (K/2))
       {
+        // this is the most important for performance considerations
         reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
       }
       else
@@ -3597,6 +3597,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

     if(inner_idx+num_values_4bit < K)
     {
+      // this is also relatively important for performance
       if(BITS==16)
       {
         reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
@@ -3618,13 +3619,15 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

     }
     else
+      #pragma unroll
       for(int k = 0; k < num_values_4bit; k++)
         if(inner_idx + k < K)
           local_A[k] = A[inner_idx + k];
         else
           local_A[k] = T(0.0f);


+    // accumulate in float; small performance hit for Ampere, but lower error for outputs
     #pragma unroll
     for(int k = 0; k < num_values_4bit; k++)
     {
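
The new comment about accumulating in float can be illustrated outside the kernel. This is a sketch only (plain PyTorch on CPU, not the CUDA code): summing fp16 products in a float32 accumulator, as the kernel now documents, typically gives a noticeably smaller dot-product error than keeping the running sum in fp16.

import torch

torch.manual_seed(0)
a = torch.randn(4096, dtype=torch.float16)
b = torch.randn(4096, dtype=torch.float16)
ref = torch.dot(a.double(), b.double())           # high-precision reference

acc16 = torch.zeros((), dtype=torch.float16)       # fp16 accumulator
acc32 = torch.zeros((), dtype=torch.float32)       # fp32 accumulator, like the kernel
for x, y in zip(a, b):
    acc16 = acc16 + x * y
    acc32 = acc32 + (x * y).float()
print((acc16.double() - ref).abs().item(), (acc32.double() - ref).abs().item())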

csrc/ops.cu

Lines changed: 1 addition & 0 deletions
@@ -735,6 +735,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
   int num_blocks = (m+3)/4;

   kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+  CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n)

tests/test_functional.py

Lines changed: 106 additions & 64 deletions
@@ -2262,7 +2262,7 @@ def test_fp4_quant(dtype):
     A2 = F.dequantize_fp4(qa, SA)

     err = (A1 - A2).abs().float()
-    relerr = (err/A1.abs().float()).mean()
+    relerr = (err/(A1.abs().float()+1e-8)).mean()
     idx = err > 1.0
     err = err.mean()

@@ -2361,91 +2361,133 @@ def test_normal_map_tree():

 @pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
+@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_gemv_4bit(dtype, storage_type, double_quant):
-    print('')
-    for dim in [128, 256, 512, 1024, 2048, 4096]:
+def test_gemv_4bit(dtype, storage_type, double_quant, kind):
+    for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
     #for dim in [4*1024]:
-    #for dim in [1*16]:
-        errs = []
-        relerrs = []
-        max_err = 0
-        max_relerr = 0
+    #for dim in [1*128]:
+        errs1 = []
+        errs2 = []
+        errs3 = []
+        relerrs1 = []
+        relerrs2 = []
+        relerrs3 = []
+        max_errs1 = []
+        max_errs2 = []
+        max_errs3 = []
+

         for i in range(100):
-            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, dim, dtype=dtype, device='cuda')
-            #B = torch.randn(4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-
-            #print('')
-            #print(A)
-            #print(B.t())
-            #A[:, :-1] = 0
-            #B[:, :-1] = 0
-            #A.flatten()[:-1] = 0
-            #B.flatten()[:-1] = 0
+            if kind == 'fc1':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'fc2':
+                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'attn':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'attn_packed':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

             qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
-            #F.dequantize_4bit(qB, state)
-
             C3 = torch.matmul(A, B.t())
             C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)

-            #print(state)
-            #print(qB)
+            err1 = (C1-C2).abs().float()
+            err2 = (C3-C2).abs().float()
+            err3 = (C3-C1).abs().float()
+
+            mag1 = torch.abs(C1).float()+1e-5
+            mag2 = torch.abs(C3).float()+1e-5
+            mag3 = torch.abs(C3).float()+1e-5
+
+            relerr1 = err1/mag1
+            relerr2 = err2/mag2
+            relerr3 = err3/mag3

-            #print('')
-            #print(A)
-            #print(B)
-            #print('='*89)
-            #print(C3)
+            max_err1 = err1.max()
+            max_err2 = err2.max()
+            max_err3 = err3.max()

-            #print(C1.shape, C2.shape)
+            errs1.append(err1.mean().item())
+            errs2.append(err2.mean().item())
+            errs3.append(err3.mean().item())

-            # tensor cores are non-deterministic
-            # so we need to analyze errors around the mean
-            # to test our implementation
-            err = torch.abs(C1-C2).float()
-            mag = torch.abs(C1).float()+1e-5
-            relerr = err/mag
-            max_err = max(err.max(), max_err)
-            max_relerr = max(relerr.max(), max_relerr)
-            err = err.mean().item()
-            relerr = relerr.mean().item()
-            #print(err)
+            relerrs1.append(relerr1.mean().item())
+            relerrs2.append(relerr2.mean().item())
+            relerrs3.append(relerr3.mean().item())

-            errs.append(err)
-            relerrs.append(relerr)
+            max_errs1.append(max_err1.item())
+            max_errs2.append(max_err2.item())
+            max_errs3.append(max_err3.item())

         c = int(C1.numel()*0.0014*(dim/256))+1

         c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
-        #print('')
-        #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
-        #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
-        #print(dim, (max_err.item(), max_relerr.item()))
-        print(C1.flatten()[-20:])
-        print(C2.flatten()[-20:])
-        #print(C1.flatten())
-        #print(C2.flatten())
-        #print(C3.flatten()[-20:])
-        print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
-        print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
+        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
+        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
+        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
+        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
+        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
+        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
+        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
+        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
+        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
+        absratio = err2/err3
+        relratio = relerr2/relerr3
+        maxratio = relerr2/relerr3
+
+        # for debugging if the tests fails
+        #
+        #print('='*80)
+        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
+        #print(C1.flatten()[-20:])
+        #print(C2.flatten()[-20:])
+        #print(f'inference vs training abs: {err1}')
+        #print(f'inference vs training rel: {relerr1}')
+        #print(f'inference vs training max: {maxerr1}')
+        #print(f'inference vs training vs torch err ratio abs: {absratio}')
+        #print(f'inference vs training vs torch err ratio rel: {relratio}')
+        #print(f'inference vs training vs torch err ratio max: {maxratio}')
         if dtype == torch.float16:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
+            if dim <= 512:
+                assert err1 < 7e-5
+                assert relerr1 < 0.0008
+            else:
+                assert err1 < 6e-5
+                assert relerr1 < 2e-4
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.005 and relratio > 0.995
+            assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.float32:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-8
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 1e-7
+            if dim <= 512:
+                assert err1 < 5e-8
+                assert relerr1 < 1e-6
+                assert maxerr1 < 1e-7
+            else:
+                assert err1 < 5e-8
+                assert relerr1 < 8e-6
+                assert maxerr1 < 1e-7
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.005 and relratio > 0.995
+            assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.bfloat16:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
+            if dim <= 512:
+                assert err1 < 5e-4
+                assert relerr1 < 0.007
+                assert maxerr1 < 0.015
+            else:
+                assert err1 < 2e-4
+                assert relerr1 < 0.002
+                assert maxerr1 < 0.0012
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.04 and relratio > 0.96
+            assert maxratio < 1.02 and maxratio > 0.98

 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
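
To run a single configuration of the expanded test locally, a pytest invocation along these lines should work (the -k expression matches the parametrize ids shown above; the exact combination chosen here is just an example):

pytest tests/test_functional.py -k "test_gemv_4bit and fc1 and nf4 and fp16 and DQ_False" -x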

0 commit comments
