|
16 | 16 | #include <thrust/device_vector.h> |
17 | 17 | #include <mma.h> |
18 | 18 |
|
19 | | -#include <cooperative_groups/memcpy_async.h> |
20 | | -#include <cuda/pipeline> |
21 | 19 |
|
22 | 20 | #define HLF_MAX 65504 |
23 | 21 | #define TH 1024 |
24 | 22 | #define NUM 4 |
25 | 23 | #define NUM_BLOCK 4096 |
26 | 24 |
|
27 | | -using namespace nvcuda; |
28 | 25 |
|
29 | 26 | // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda |
30 | 27 | __device__ float atomicMax(float* address, float val) { |
@@ -3094,6 +3091,9 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l |
3094 | 3091 | #define WARPS 5 |
3095 | 3092 | template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) |
3096 | 3093 | { |
| 3094 | + |
| 3095 | +#if __CUDA_ARCH__ >= 750 |
| 3096 | + using namespace nvcuda; |
3097 | 3097 | int col_offset = blockIdx.x *32; |
3098 | 3098 | const int warp_id = threadIdx.x / 32; |
3099 | 3099 | const int half_warp_id = threadIdx.x / 16; |
@@ -3294,11 +3294,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, |
3294 | 3294 |
|
3295 | 3295 | if(col_offset + warp_lane < M) |
3296 | 3296 | out[col_offset + warp_lane] = smem_A[warp_lane]; |
| 3297 | +#endif |
3297 | 3298 | } |
3298 | 3299 |
|
3299 | 3300 | template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) |
3300 | 3301 | { |
3301 | 3302 |
|
| 3303 | +#if __CUDA_ARCH__ >= 750 |
| 3304 | + using namespace nvcuda; |
3302 | 3305 | int col_offset = blockIdx.x *32; |
3303 | 3306 | const int warp_id = threadIdx.x / 32; |
3304 | 3307 | const int half_warp_id = threadIdx.x / 16; |
@@ -3459,6 +3462,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i |
3459 | 3462 |
|
3460 | 3463 | if(col_offset + warp_lane < M) |
3461 | 3464 | out[col_offset + warp_lane] = smem_A[warp_lane]; |
| 3465 | +#endif |
3462 | 3466 | } |
3463 | 3467 |
|
3464 | 3468 | //#define ROWS 2 |
|
0 commit comments