|
| 1 | +#include <cuda_runtime.h> |
| 2 | +#include <cuda_fp16.h> |
| 3 | +#include <stdexcept> |
| 4 | +#include <algorithm> |
| 5 | + |
| 6 | +constexpr int WARP_SIZE = 64; |
| 7 | + |
| 8 | +template <typename T> |
| 9 | +__device__ __forceinline__ T silu(const T& x) { |
| 10 | + // x * sigmoid(x) |
| 11 | + return (T)(((float)x) / (1.0f + expf((float)-x))); |
| 12 | +} |
| 13 | + |
| 14 | +template <typename T> |
| 15 | +__device__ __forceinline__ T loadnt(T* addr) { |
| 16 | + return __builtin_nontemporal_load(addr); |
| 17 | +} |
| 18 | + |
| 19 | +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { |
| 20 | + auto addr_alias = reinterpret_cast<const float*>(addr); |
| 21 | + auto dat0 = loadnt(addr_alias); |
| 22 | + auto dat1 = loadnt(addr_alias + 1); |
| 23 | + auto dat2 = loadnt(addr_alias + 2); |
| 24 | + auto dat3 = loadnt(addr_alias + 3); |
| 25 | + // auto dat0 = *(addr_alias); |
| 26 | + // auto dat1 = *(addr_alias+1); |
| 27 | + // auto dat2 = *(addr_alias+2); |
| 28 | + // auto dat3 = *(addr_alias+3); |
| 29 | + return make_float4(dat0, dat1, dat2, dat3); |
| 30 | +} |
| 31 | + |
| 32 | +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume |
| 33 | +// N=1 for time being grid is M/A_NUM_ROWS blocks |
| 34 | +template <int NUM_A_ROWS_PER_BLOCK> |
| 35 | +__global__ void LLGemm_Silu_kernel(float4* af4, __half2* bf4, _Float16* c, |
| 36 | + const int d) { |
| 37 | + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; |
| 38 | + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 * blockDim.x; |
| 39 | + const int row_addr_d = row_addr + d * blockDim.x; |
| 40 | + // int row_addr_1 = row_addr + CUDA_NUM_THREADS; |
| 41 | + // int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; |
| 42 | + // int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; |
| 43 | + const int threadid = threadIdx.x; |
| 44 | + const int warp = threadIdx.x / WARP_SIZE; |
| 45 | + const int lane = threadIdx.x % WARP_SIZE; |
| 46 | + const int num_warps = blockDim.x / WARP_SIZE; |
| 47 | + const int qwarpid = threadid / 16; |
| 48 | + const int qthreadid = threadid % 16; |
| 49 | + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; |
| 50 | + // float4 colB_elem4; |
| 51 | + __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; |
| 52 | + float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; |
| 53 | + __half2 acch2; |
| 54 | + |
| 55 | + // rowA_elem4 = af4[row_addr + threadid]; |
| 56 | + //__syncthreads(); |
| 57 | + // rowA_elem4_1 = af4[row_addr_1 + threadid]; |
| 58 | + // rowA_elem4_2 = af4[row_addr_2 + threadid]; |
| 59 | + // rowA_elem4_3 = af4[row_addr_3 + threadid]; |
| 60 | +#pragma unroll |
| 61 | + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK / 2; i++) { |
| 62 | + rowA_elem4[2 * i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]); |
| 63 | + rowA_elem4[2 * i + 1] = |
| 64 | + load_ntmprl(&af4[row_addr_d + i * blockDim.x + threadid]); |
| 65 | + // rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid]; |
| 66 | + //__syncthreads(); |
| 67 | + } |
| 68 | + colB_elem4x = bf4[threadid * 4 + 0]; |
| 69 | + colB_elem4y = bf4[threadid * 4 + 1]; |
| 70 | + colB_elem4z = bf4[threadid * 4 + 2]; |
| 71 | + colB_elem4w = bf4[threadid * 4 + 3]; |
| 72 | + |
| 73 | + // __syncthreads(); |
| 74 | + __half2 Af2; |
| 75 | + float2 S; |
| 76 | + // auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4); |
| 77 | + // auto Bf2x = *Bh2ptr; |
| 78 | + // auto Bf2y = *(Bh2ptr+1); |
| 79 | + // auto Bf2z = *(Bh2ptr+2); |
| 80 | + // auto Bf2w = *(Bh2ptr+3); |
| 81 | + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); |
| 82 | + __half2* ah2lptr; |
| 83 | +#pragma unroll |
| 84 | + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { |
| 85 | + ah2lptr = Ah2ptr + i * 4; |
| 86 | + Af2 = *(ah2lptr); |
| 87 | + acch2 = __hmul2(Af2, colB_elem4x); |
| 88 | + Af2 = *(ah2lptr + 1); |
| 89 | + acch2 = __hfma2(Af2, colB_elem4y, acch2); |
| 90 | + Af2 = *(ah2lptr + 2); |
| 91 | + acch2 = __hfma2(Af2, colB_elem4z, acch2); |
| 92 | + Af2 = *(ah2lptr + 3); |
| 93 | + acch2 = __hfma2(Af2, colB_elem4w, acch2); |
| 94 | + S = __half22float2(acch2); |
| 95 | + acc[i] = S.x + S.y; |
| 96 | + } |
| 97 | + |
| 98 | +#pragma unroll |
| 99 | + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { |
| 100 | +#pragma unroll |
| 101 | + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { |
| 102 | + acc[i] += __shfl_xor(acc[i], mask); |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + // Warp leaders store the data to shared memory. |
| 107 | + // if (lane == 0) { |
| 108 | + // #pragma unroll |
| 109 | + // for (int i=0; i<NUM_A_ROWS_PER_BLOCK; i++) { |
| 110 | + // red_smem[i][warp] = acc[i]; |
| 111 | + // } |
| 112 | + // } |
| 113 | + |
| 114 | + if (lane < NUM_A_ROWS_PER_BLOCK) { |
| 115 | + red_smem[lane][warp] = acc[lane]; |
| 116 | + } |
| 117 | + |
| 118 | + // Make sure the data is in shared memory. |
| 119 | + __syncthreads(); |
| 120 | + if (qwarpid < NUM_A_ROWS_PER_BLOCK) { |
| 121 | + // if (threadid<64) { |
| 122 | + // #pragma unroll |
| 123 | + // for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) { |
| 124 | + // acc[i+2*qwarpid] = 0.0; |
| 125 | + // } |
| 126 | + ////acc[qwarpid] = 0.0; |
| 127 | + |
| 128 | + ////if (qthreadid<num_warps) { |
| 129 | + // #pragma unroll |
| 130 | + // for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) { |
| 131 | + // acc[i+2*qwarpid] = red_smem[i+2*qwarpid][qthreadid]; |
| 132 | + // } |
| 133 | + ////acc[qwarpid] = red_smem[qwarpid][qthreadid]; |
| 134 | + |
| 135 | + ////} |
| 136 | + acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; |
| 137 | + // if (threadid<32) { |
| 138 | +#pragma unroll |
| 139 | + for (int mask = 16 / 2; mask >= 1; mask /= 2) { |
| 140 | + // #pragma unroll |
| 141 | + // for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) { |
| 142 | + // acc[i+2*qwarpid] += __shfl_xor(acc[i+2*qwarpid], mask); |
| 143 | + // } |
| 144 | + acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); |
| 145 | + } |
| 146 | + float oval2 = __shfl_xor(acc[qwarpid], 16); |
| 147 | + // acc[1] = __shfl_xor(acc[1],16); |
| 148 | + // acc[3] = __shfl_xor(acc[3],16); |
| 149 | + //} |
| 150 | + // __syncthreads(); |
| 151 | + // if (threadid < NUM_A_ROWS_PER_BLOCK/2) { |
| 152 | + if (lane == 0 or lane == 32) { |
| 153 | + // oval = __float22half2_rn(make_float2(acc[qwarpid],oval2)); |
| 154 | + // c[blockIdx.x*NUM_A_ROWS_PER_BLOCK/2+qwarpid/2] = oval; |
| 155 | + |
| 156 | + c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = |
| 157 | + silu(acc[qwarpid]) * oval2; |
| 158 | + } |
| 159 | + } // threadid<WARP_SIZE |
| 160 | +} |
| 161 | +// define the kernel calling code: |
| 162 | +// template <typename T> |
| 163 | +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, |
| 164 | + cudaStream_t stream, const int rows_per_block = 4) { |
| 165 | + float4* af4 = reinterpret_cast<float4*>(in_a); |
| 166 | + auto* bf4 = reinterpret_cast<__half2*>(in_b); |
| 167 | + auto* c = reinterpret_cast<_Float16*>(out_c); |
| 168 | + const int d = M / 2; |
| 169 | + const int NUM_THREADS = K * 2 / 16; |
| 170 | + int NUM_BLOCKS = M / rows_per_block; |
| 171 | + if (rows_per_block == 2) { |
| 172 | + LLGemm_Silu_kernel<2> |
| 173 | + <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d); |
| 174 | + } else if (rows_per_block == 4) { |
| 175 | + LLGemm_Silu_kernel<4> |
| 176 | + <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d); |
| 177 | + } else if (rows_per_block == 8) { |
| 178 | + LLGemm_Silu_kernel<8> |
| 179 | + <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d); |
| 180 | + } else if (rows_per_block == 16) { |
| 181 | + LLGemm_Silu_kernel<16> |
| 182 | + <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d); |
| 183 | + } else { |
| 184 | + NUM_BLOCKS = M / 4; |
| 185 | + LLGemm_Silu_kernel<4> |
| 186 | + <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d); |
| 187 | + } |
| 188 | + |
| 189 | + cudaError_t err = cudaGetLastError(); |
| 190 | + if (cudaSuccess != err) |
| 191 | + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); |
| 192 | +} |
0 commit comments