Skip to content

Commit 869b7e8

Browse files
committed
Warp multi-specialization 240.
1 parent 77f15fd commit 869b7e8

File tree

2 files changed

+56
-14
lines changed

2 files changed

+56
-14
lines changed

csrc/kernels.cu

Lines changed: 52 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30583058
const int half_warp_lane = threadIdx.x % 16;
30593059
const int batch_size_warps = (WARPS-1)*2;
30603060

3061-
T local_A[1];
3062-
T local_B[32];
3061+
T local_A[2];
3062+
T local_B[64];
30633063

30643064
const int a_tile_offset = 16;
30653065
const int b_tile_offset = (16*32 + 16);
@@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
30753075

30763076
int ticktock = 0;
30773077
int idx = 0 + threadIdx.x;
3078+
int loaded_values = 0;
30783079
// prefetch
30793080
if(idx < K && warp_id < (WARPS-1))
30803081
{
3081-
local_A[0] = A[idx];
3082+
if(loaded_values == 0)
3083+
{
3084+
local_A[0] = A[idx];
3085+
local_A[1] = A[idx+blockDim.x-32];
30823086

3083-
#pragma unroll 32
3084-
for(int col = 0; col < 32; col++)
3085-
local_B[col] = B[(col_offset+col)*ldb+idx];
3087+
#pragma unroll 32
3088+
for(int col = 0; col < 32; col++)
3089+
{
3090+
local_B[col] = B[(col_offset+col)*ldb+idx];
3091+
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
3092+
}
3093+
loaded_values = 1;
3094+
}
3095+
else
3096+
{
3097+
local_A[0] = local_A[1];
3098+
loaded_values--;
3099+
3100+
#pragma unroll 32
3101+
for(int col = 0; col < 32; col++)
3102+
local_B[col] = local_B[col+32];
3103+
}
30863104

30873105
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
30883106

@@ -3113,11 +3131,35 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
31133131
__syncthreads();
31143132
if(idx < K && warp_id < (WARPS-1))
31153133
{
3116-
local_A[0] = A[idx];
3134+
//local_A[0] = A[idx];
31173135

3118-
#pragma unroll 32
3119-
for(int col = 0; col < 32; col++)
3120-
local_B[col] = B[(col_offset+col)*ldb+idx];
3136+
//#pragma unroll 32
3137+
//for(int col = 0; col < 32; col++)
3138+
// local_B[col] = B[(col_offset+col)*ldb+idx];
3139+
if(loaded_values == 0)
3140+
{
3141+
local_A[0] = A[idx];
3142+
local_A[1] = A[idx+blockDim.x-32];
3143+
3144+
#pragma unroll 32
3145+
for(int col = 0; col < 32; col++)
3146+
{
3147+
local_B[col] = B[(col_offset+col)*ldb+idx];
3148+
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
3149+
}
3150+
loaded_values = 1;
3151+
}
3152+
else
3153+
{
3154+
local_A[0] = local_A[1];
3155+
loaded_values--;
3156+
3157+
#pragma unroll 32
3158+
for(int col = 0; col < 32; col++)
3159+
local_B[col] = local_B[col+32];
3160+
3161+
3162+
}
31213163

31223164
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
31233165

tests/test_functional.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype):
23762376
#print('')
23772377
#print(A)
23782378
#print(B.t())
2379-
#A[:, :-3] = 0
2380-
#B[:, :-3] = 0
2379+
#A[:, :-1] = 0
2380+
#B[:, :-1] = 0
23812381

23822382

23832383
C1 = torch.matmul(A, B.t())
@@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype):
23992399

24002400
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
24012401
# print('')
2402-
# print(i, err, mag.item(), relerr.item())
2402+
# print(i, err, relerr)
24032403
# print(A.flatten()[-6:])
24042404
# print(B.flatten()[-6:])
24052405
# out = A.flatten()[-6:]*B.flatten()[-6:]
@@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype):
24122412

24132413
c = int(C1.numel()*0.0014*(dim/256))+1
24142414

2415-
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
2415+
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True)
24162416
#print(c/math.sqrt(dim))
24172417
print('')
24182418
print(dim, sum(errs)/len(errs)/math.sqrt(dim))

0 commit comments

Comments
 (0)