
Commit 9192c9d

Tighter and scaled error analysis.
1 parent f9bfea8 commit 9192c9d

File tree: 2 files changed (+70, -42 lines)


csrc/kernels.cu

Lines changed: 14 additions & 1 deletion
@@ -3123,6 +3123,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
   }
   ticktock = ticktock == 0 ? 1 : 0;
 
+  //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32)
   {
     idx = base_idx + threadIdx.x;
@@ -3155,8 +3156,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       for(int col = 0; col < 32; col++)
         smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
     }
-    ticktock = ticktock == 0 ? 1 : 0;
+    //ticktock = ticktock == 0 ? 1 : 0;
 
+    __syncthreads();
     if(warp_id == (WARPS-1))
       for(int k = 0; k < batch_size_warps; k++)
       {
@@ -3166,11 +3168,22 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       }
     }
 
+    //__syncthreads();
+    //if(warp_id == (WARPS-1))
+    //  for(int k = 0; k < batch_size_warps; k++)
+    //  {
+    //    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+    //    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+    //    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+    //  }
+    __syncthreads();
+
     // 129 mu
     if(warp_id == (WARPS-1))
       wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
     __syncthreads();
 
+
     //if(threadIdx.x >= 16){ return; }
     //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);

tests/test_functional.py

Lines changed: 56 additions & 41 deletions
@@ -2355,47 +2355,62 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(100):
-        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-        A = torch.randn(1, 128+32, dtype=dtype, device='cuda')
-        B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128)
-
-        #print('')
-        #print(A)
-        #print(B.t())
-        #A[:, :-3] = 0
-        #B[:, :-3] = 0
-
-
-        C1 = torch.matmul(A, B.t())
-        C2 = F.cutlass3_gemm(A, B.t())
-        err = C1-C2
-
-        # tensor cores are non-deterministic
-        # so we need to analyze errors around the mean
-        # to test our implementation
-        err = torch.abs(err.mean()).item()
-        mag = torch.abs(C1).mean()
-        relerr = err/mag
-
-        if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-            print('')
-            print(i, err, mag.item(), relerr.item())
-            print(A.flatten()[-6:])
-            print(B.flatten()[-6:])
-            out = A.flatten()[-6:]*B.flatten()[-6:]
-            print(out)
-            print(out[:-1].sum())
-            print('='*80)
-            print(C1.flatten()[-6:])
-            print(C2.flatten()[-6:])
-            #assert False, 'ERROR'
-
-        c = int(C1.numel()*0.001)
-        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
+    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
+        for i in range(100):
+            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            B = torch.randn(4*496, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+
+            #print('')
+            #print(A)
+            #print(B.t())
+            #A[:, :-3] = 0
+            #B[:, :-3] = 0
+
+
+            C1 = torch.matmul(A, B.t())
+            C2 = F.cutlass3_gemm(A, B.t())
+
+            # tensor cores are non-deterministic
+            # so we need to analyze errors around the mean
+            # to test our implementation
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
+            relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
+
+            errs.append(err)
+            relerrs.append(relerr)
+
+            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            #    print('')
+            #    print(i, err, mag.item(), relerr.item())
+            #    print(A.flatten()[-6:])
+            #    print(B.flatten()[-6:])
+            #    out = A.flatten()[-6:]*B.flatten()[-6:]
+            #    print(out)
+            #    print(out[:-1].sum())
+            #    print('='*80)
+            #    print(C1.flatten()[-6:])
+            #    print(C2.flatten()[-6:])
+            #    #assert False, 'ERROR'
+
+        c = int(C1.numel()*0.00125*(dim/256))+1
+        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))
 
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
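
For reference, the per-dim error analysis that the updated test performs can be read as the standalone sketch below. This is a hedged illustration, not part of the commit: the function name scaled_gemm_error_analysis and the run_gemm parameter are placeholders (run_gemm stands in for F.cutlass3_gemm), a CUDA device is assumed, and the thresholds and the 4*496 output dimension are copied from the diff above.

import math
import torch

def scaled_gemm_error_analysis(run_gemm, dim, trials=100, dtype=torch.float16):
    # run_gemm is a placeholder for the kernel under test (F.cutlass3_gemm in the diff).
    errs, relerrs = [], []
    max_err, max_relerr = 0.0, 0.0
    for _ in range(trials):
        A = torch.randn(1, dim, dtype=dtype, device='cuda')
        # B is scaled by 1/sqrt(dim) so the reference outputs keep roughly unit variance
        B = torch.randn(4*496, dim, dtype=dtype, device='cuda') / math.sqrt(dim)
        C1 = torch.matmul(A, B.t())   # fp16 reference
        C2 = run_gemm(A, B.t())       # kernel under test
        err = torch.abs(C1 - C2)                # elementwise absolute error
        relerr = err / (torch.abs(C1) + 1e-8)   # relative error with a divide-by-zero guard
        max_err = max(max_err, err.max().item())
        max_relerr = max(max_relerr, relerr.max().item())
        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
    # Tensor cores are non-deterministic, so instead of exact elementwise closeness
    # the test tolerates a small number of mismatches that scales with output size and dim.
    count = int(C1.numel() * 0.00125 * (dim / 256)) + 1
    # Mean errors are reported divided by sqrt(dim) to show how the error grows
    # with the reduction dimension.
    mean_err = sum(errs) / len(errs) / math.sqrt(dim)
    mean_relerr = sum(relerrs) / len(relerrs) / math.sqrt(dim)
    return mean_err, mean_relerr, max_err, max_relerr, count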
