Commit f9bfea8

Baseline for debugging.
1 parent 7bfa09d commit f9bfea8

File tree

4 files changed (+66 −19 lines):
  bitsandbytes/functional.py
  csrc/kernels.cu
  csrc/ops.cu
  tests/test_functional.py


bitsandbytes/functional.py

Lines changed: 1 addition & 1 deletion

@@ -1467,7 +1467,7 @@ def cutlass3_gemm(
     lda = Bshape[1]
     ldc = Bshape[0]
     ldb = (ldb+1)//2
-    print(m, n, k, lda, ldb, ldc)
+    #print(m, n, k, lda, ldb, ldc)
     is_on_gpu([B, A, out])
     m = ct.c_int32(m)
     n = ct.c_int32(n)

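One note on the context lines above the silenced debug print: `ldb = (ldb+1)//2` is a ceiling division by two, which is what you would expect if B is stored with two 4-bit values packed per byte (an assumption about the storage format; m, n, k and the initial ldb are set earlier in `cutlass3_gemm` and are not shown in this hunk). A minimal sketch:

# Hedged sketch of the halved leading dimension: with two 4-bit values packed
# per byte, an odd element count still needs the extra byte.
def packed_ld(num_elements):
    return (num_elements + 1) // 2

assert packed_ld(4096) == 2048
assert packed_ld(4097) == 2049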
csrc/kernels.cu

Lines changed: 28 additions & 3 deletions

@@ -3061,9 +3061,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
   T local_A[1];
   T local_B[32];

-  const int a_tile_offset = (8*16 + 16);
-  const int b_tile_offset = (16*32 + 16);
-  const int c_tile_offset = 8*32 + 24;
+  const int a_tile_offset = (8*16);
+  const int b_tile_offset = (16*32);

   __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))];
   __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
@@ -3109,6 +3108,19 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       for(int col = 0; col < 32; col++)
         smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col];
     }
+    else if(warp_id < (WARPS-1))
+    {
+      local_A[0] = T(0.0);
+      smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = T(0.0);
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B[col] = T(0.0f);
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = T(0.0f);
+    }
     ticktock = ticktock == 0 ? 1 : 0;

     for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32)
@@ -3130,6 +3142,19 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       for(int col = 0; col < 32; col++)
         smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
     }
+    else if(warp_id < (WARPS-1))
+    {
+      local_A[0] = T(0.0);
+      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B[col] = 0.0f;
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
+    }
     ticktock = ticktock == 0 ? 1 : 0;

   if(warp_id == (WARPS-1))

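The substantive change here is the pair of new `else if(warp_id < (WARPS-1))` branches: warps that have nothing to load for a given iteration now explicitly zero their slice of `smem_A`/`smem_B` (and the extra padding terms were dropped from the tile offsets). The sketch below is a hedged PyTorch analogy rather than the kernel itself; it only illustrates the underlying reason the zero-fill is safe: padding the K dimension with zeros leaves the accumulated product unchanged, whereas stale shared-memory contents would corrupt it.

# Hedged illustration, in PyTorch rather than CUDA, with made-up sizes.
import torch

K, tile = 150, 16                 # K deliberately not a multiple of the tile width
A = torch.randn(1, K)
B = torch.randn(32, K)

ref = A @ B.t()

# Pad the K dimension up to a multiple of the tile width with zeros (what the
# kernel now does explicitly for idle warps) and check the result is unchanged.
pad = (-K) % tile                 # 10 extra columns in this example
A_p = torch.nn.functional.pad(A, (0, pad))
B_p = torch.nn.functional.pad(B, (0, pad))
assert torch.allclose(ref, A_p @ B_p.t())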
csrc/ops.cu

Lines changed: 8 additions & 8 deletions

@@ -680,14 +680,14 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out

   int num_blocks = (m+31)/32;

-  cout << num_blocks << endl;
-  cout << lda << endl;
-  cout << ldb << endl;
-  cout << ldc << endl;
-
-  cout << m << endl;
-  cout << n << endl;
-  cout << k << endl;
+  //cout << num_blocks << endl;
+  //cout << lda << endl;
+  //cout << ldb << endl;
+  //cout << ldc << endl;
+
+  //cout << m << endl;
+  //cout << n << endl;
+  //cout << k << endl;
   //if(bits == 32)
     //gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
     //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);

tests/test_functional.py

Lines changed: 29 additions & 7 deletions

@@ -2355,25 +2355,47 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(1):
+    for i in range(100):
         #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
         #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
         #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
         #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+        A = torch.randn(1, 128+32, dtype=dtype, device='cuda')
+        B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128)

         #print('')
         #print(A)
         #print(B.t())
+        #A[:, :-3] = 0
+        #B[:, :-3] = 0


         C1 = torch.matmul(A, B.t())
         C2 = F.cutlass3_gemm(A, B.t())
-        print(C1)
-        print(C2)
-
-        torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.06)
+        err = C1-C2
+
+        # tensor cores are non-deterministic
+        # so we need to analyze errors around the mean
+        # to test our implementation
+        err = torch.abs(err.mean()).item()
+        mag = torch.abs(C1).mean()
+        relerr = err/mag
+
+        if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            print('')
+            print(i, err, mag.item(), relerr.item())
+            print(A.flatten()[-6:])
+            print(B.flatten()[-6:])
+            out = A.flatten()[-6:]*B.flatten()[-6:]
+            print(out)
+            print(out[:-1].sum())
+            print('='*80)
+            print(C1.flatten()[-6:])
+            print(C2.flatten()[-6:])
+            #assert False, 'ERROR'
+
+        c = int(C1.numel()*0.001)
+        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)

 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])

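The rewritten test replaces the exact `assert_close` with an outlier-tolerant check: the mean error is compared against the output magnitude (per the comment about non-deterministic tensor-core accumulation), and `assert_all_approx_close` allows up to `c`, here 0.1% of the output elements, to miss the tolerance. The helper itself is defined elsewhere in tests/test_functional.py; the sketch below is only an assumed reading of what such a check does, under a made-up name.

import torch

def approx_close_with_outliers(a, b, rtol=1e-5, atol=1e-2, count=0):
    # Elements close under the usual |a - b| <= atol + rtol*|b| criterion.
    close = torch.isclose(a, b, rtol=rtol, atol=atol)
    num_outliers = int((~close).sum().item())
    # Tolerate up to `count` stragglers; reduction-order differences on tensor
    # cores make a handful of out-of-tolerance elements expected.
    assert num_outliers <= count, f"{num_outliers} outliers, only {count} allowed"

Called as in the diff, with count = int(C1.numel()*0.001), roughly one element in a thousand may fall outside tolerance before the test fails.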