CUDA: Volta tensor core support for MMF #16843
Changes from 1 commit
```diff
@@ -18,6 +18,10 @@
 #include "common.cuh"

+// On Volta each warp is doing 4 8x8 mma operations in parallel.
+// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
+// However, the i indices in this file are by default permuted to simplify the index calculations.
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM

 #if CUDART_VERSION >= 11080
```
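To make the stacked-tile layout described in the new comment more concrete, here is a small host-side sketch (not part of the PR). The index arithmetic is copied from the `GGML_CUDA_MMA_NO_VOLTA_PERM` branch of `tile<32, 8, half2>::get_i` in this diff; the surrounding harness, including the `lane` variable and the printing, is purely illustrative.

```cuda
// Illustration only: print the Volta lane -> row mapping for the non-permuted
// 32x8 layout (4 input tiles stacked in the I direction). The index expression
// matches the GGML_CUDA_MMA_NO_VOLTA_PERM branch of get_i in this diff.
#include <cstdio>

int main() {
    for (int lane = 0; lane < 32; ++lane) {
        const int block   = ((lane % 16) / 4) * 8;   // which of the 4 stacked 8-row tiles
        const int halfrow = (lane / 16) * 4;         // upper or lower 4 rows of that tile
        const int row     = lane % 4;                // row within the group of 4 threads
        const int i       = block | halfrow | row;   // same value as get_i would return
        printf("lane %2d -> i %2d\n", lane, i);
    }
    return 0;
}
```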
```diff
@@ -86,6 +90,7 @@ namespace ggml_cuda_mma {
                 return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
```
```diff
@@ -102,6 +107,32 @@
                 return threadIdx.x % 32;
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
+#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+#else
+                return (l & 2) | (threadIdx.x & ~2);
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 32 && J == 8) {
+                return (threadIdx.x & 2) | (l & (4 + 1));
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
 #else
```
```diff
@@ -111,26 +142,28 @@
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && (J == 4 || J == 8)) {
                 return threadIdx.x / 4;
-            } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 8 + threadIdx.x / 4;
+            } else if constexpr ((I == 16 || I == 32) && J == 8) {
+                return ((l / 2) * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return 4 * l + threadIdx.x % 4;
-            } else if constexpr (I == 16 && J == 8) {
-                return 2 * (threadIdx.x % 4) + l % 2;
+                return (l * 4) | (threadIdx.x % 4);
+            } else if constexpr ((I == 16 || I == 32) && J == 8) {
+                return ((threadIdx.x % 4) * 2) | (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
 #endif // defined(GGML_USE_HIP)
```
```diff
@@ -140,32 +173,68 @@
     struct tile<I_, J_, half2> {
         static constexpr int I = I_;
         static constexpr int J = J_;

+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+#else
+                return threadIdx.x;
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr ((I == 8 || I == 32) && J == 8) {
+                return l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};

         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l & 2) * 2) | (threadIdx.x % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     };

     template <int I_, int J_>
```
```diff
@@ -179,23 +248,25 @@
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
```
Inline review comment on the change above:

> Is this more performant? I find the earlier version easier to read.

Reply:

> It's a good thing you're asking that, because I misremembered the table in the CUDA documentation that shows instruction throughput. The way I remembered it, integer additions and bitwise operations had the same throughput, but at the silicon level you would have lower power draw. In actuality, the throughput of additions is twice that of bitwise operations.
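Independent of throughput, the substitution is only valid because the operands of `|` in these rewritten index expressions occupy disjoint bit ranges, so bitwise OR and addition produce the same value. A minimal standalone check for one representative case (not part of the PR; the loop bounds are illustrative):

```cuda
// Illustration only: threadIdx.x / 4 contributes bits 0..2 and l * 8 contributes
// bits 3 and up, so (l * 8) | (threadIdx.x / 4) equals (l * 8) + (threadIdx.x / 4).
#include <cassert>

int main() {
    for (int l = 0; l < 4; ++l) {
        for (int lane = 0; lane < 32; ++lane) {
            const int lo = lane / 4;        // 0..7, occupies bits 0..2
            const int hi = l * 8;           // multiple of 8, occupies bits 3 and up
            assert((hi | lo) == (hi + lo)); // disjoint bit ranges: OR == ADD
        }
    }
    return 0;
}
```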
```diff
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
     };
```
```diff
@@ -263,8 +334,12 @@
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-        load_generic(xs0, stride);
-        GGML_UNUSED(t);
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#else
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }
```
```diff
@@ -277,11 +352,35 @@
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
 #else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#else
         load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }

+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if 1
+        // TODO: more generic handling
+        static_assert(sizeof(T) == 4, "bad type size");
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+#else
+        load_generic(t, xs0, stride);
+#endif // 1
+#else
+        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
+        load_ldmatrix(t16[0], xs0 +  0*stride, stride);
+        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
+
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
```
```diff
@@ -546,4 +645,43 @@
         NO_DEVICE_CODE;
 #endif // AMD_MFMA_AVAILABLE
     }
+
+    template <typename T1, typename T2, int J, int K>
+    static __device__ __forceinline__ void mma(
+            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
+        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
+        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
+#else
+        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
+        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
 }
```
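For orientation, here is a rough usage sketch showing how the pieces added in this diff could fit together. The tile types and the `load_ldmatrix` and `mma` overloads are the ones added above, and `load_generic` is the existing generic loader referenced in the diff; the function name `mmf_tile_sketch`, the pointer and stride arguments, and the store loop are made up for illustration and are not necessarily how the MMF kernel in the PR uses these helpers.

```cuda
// Hypothetical sketch, not from the PR: drive the new Volta 32x8 tile path.
#include "mma.cuh"  // assumed: the header being modified in this diff

using namespace ggml_cuda_mma;

static __device__ void mmf_tile_sketch(const half2 * A_src, const half2 * B_src,
                                        float * dst, const int stride) {
    tile<32, 8, half2> A;  // 32x8 operand tile: 4 stacked 8x8 Volta fragments
    tile< 8, 8, half2> B;  // mirrored B tile
    tile<32, 8, float> D;  // accumulator, zero-initialized by its member initializer

    load_ldmatrix(A, A_src, stride);  // on Volta: two 16-byte copies per thread
    load_generic (B, B_src, stride);  // plain per-element load
    mma(D, A, B);                     // on Volta: four m8n8k4 PTX instructions

    // Store through the tile's index helpers. Note that on Volta the i indices
    // are permuted by default unless GGML_CUDA_MMA_NO_VOLTA_PERM is defined.
#pragma unroll
    for (int l = 0; l < D.ne; ++l) {
        dst[D.get_i(l)*stride + D.get_j(l)] = D.x[l];
    }
}
```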