Merged
Changes from 1 commit
10 changes: 9 additions & 1 deletion ggml/src/ggml-cuda/common.cuh
@@ -224,6 +224,11 @@ static const char * cu_get_error_str(CUresult err) {
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#define VOLTA_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define TURING_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -278,7 +283,10 @@ static bool amd_mfma_available(const int cc) {
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool volta_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
}

static bool turing_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
172 changes: 155 additions & 17 deletions ggml/src/ggml-cuda/mma.cuh
@@ -18,6 +18,10 @@

#include "common.cuh"

// On Volta each warp performs 4 8x8 mma operations in parallel.
// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in the I direction and to mirror the B tile.
// However, the i indices in this file are by default permuted to simplify the index calculations.
// #define GGML_CUDA_MMA_NO_VOLTA_PERM
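// Illustrative aside (a worked reading of the get_i formulas further below, not an exhaustive spec):
// with the default permuted layout, tile<32, 8, half2>::get_i(l) is simply threadIdx.x, whereas with
// GGML_CUDA_MMA_NO_VOLTA_PERM defined the same lane maps to
//     (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4)
// e.g. lane 5 -> row 9 (8|0|1) and lane 17 -> row 5 (0|4|1). Both mappings are bijections over the
// 32 rows, so the permutation only reorders rows within the tile without changing which rows a warp covers.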

#if CUDART_VERSION >= 11080

@@ -86,6 +90,7 @@ namespace ggml_cuda_mma {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}

@@ -102,6 +107,32 @@ namespace ggml_cuda_mma {
return threadIdx.x % 32;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
#else
return (l & 2) | (threadIdx.x & ~2);
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 32 && J == 8) {
return (threadIdx.x & 2) | (l & (4 + 1));
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}
#else
@@ -111,26 +142,28 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && (J == 4 || J == 8)) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 8 + threadIdx.x / 4;
} else if constexpr ((I == 16 || I == 32) && J == 8) {
return ((l / 2) * 8) | (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 16) {
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
return 4 * l + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x % 4) + l % 2;
return (l * 4) | (threadIdx.x % 4);
} else if constexpr ((I == 16 || I == 32) && J == 8) {
return ((threadIdx.x % 4) * 2) | (l % 2);
} else if constexpr (I == 16 && J == 16) {
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}
#endif // defined(GGML_USE_HIP)
@@ -140,32 +173,68 @@ namespace ggml_cuda_mma {
struct tile<I_, J_, half2> {
static constexpr int I = I_;
static constexpr int J = J_;

#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
} else if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
#else
return threadIdx.x;
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr ((I == 8 || I == 32) && J == 8) {
return l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}
#else
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return l * 8 + threadIdx.x / 4;
return (l * 8) | (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 8) {
return (l % 2) * 8 + threadIdx.x / 4;
return ((l % 2) * 8) | (threadIdx.x / 4);
} else if constexpr (I == 32 && J == 8) {
return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return l * 4 + threadIdx.x % 4;
return (l * 4) | (threadIdx.x % 4);
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 4 + threadIdx.x % 4;
return ((l / 2) * 4) | (threadIdx.x % 4);
} else if constexpr (I == 32 && J == 8) {
return ((l & 2) * 2) | (threadIdx.x % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
};

template <int I_, int J_>
@@ -179,23 +248,25 @@ namespace ggml_cuda_mma {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return l * 8 + threadIdx.x / 4;
return (l * 8) | (threadIdx.x / 4);
Collaborator:

is this more performant? I find the earlier version easier to read

Collaborator (Author):

It's a good thing you're asking that because I misremembered the table from the CUDA documentation showing instruction throughput. The way I remembered it, integer additions and binary operations had the same throughput, but at the silicon level you would get lower power draw. In actuality, though, the throughput of additions is twice that of binary operations.
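As an aside (illustration only, not part of the PR discussion): the two forms are interchangeable here because the combined terms occupy disjoint bits — `threadIdx.x / 4` is at most 7 and therefore stays below the bits set by `l * 8`. A minimal host-side check of that equivalence:

#include <cassert>

int main() {
    for (int l = 0; l < 8; ++l) {
        for (int tid = 0; tid < 32; ++tid) {
            const int lo = tid / 4;                 // 0..7, occupies bits 0..2 only
            assert(((l * 8) | lo) == (l * 8) + lo); // disjoint bits: OR and ADD agree
        }
    }
    return 0;
}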

} else if constexpr (I == 16 && J == 8) {
return (l % 2) * 8 + threadIdx.x / 4;
return ((l % 2) * 8) | (threadIdx.x / 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return l * 4 + threadIdx.x % 4;
return (l * 4) | (threadIdx.x % 4);
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 4 + threadIdx.x % 4;
return ((l / 2) * 4) | (threadIdx.x % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
return -1;
}
}
};
@@ -263,8 +334,12 @@ namespace ggml_cuda_mma {
: "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
GGML_UNUSED(t);
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
GGML_UNUSED_VARS(t, xs0, stride);
NO_DEVICE_CODE;
#else
load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
}

@@ -277,11 +352,35 @@ namespace ggml_cuda_mma {
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
: "l"(xs));
#else
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
GGML_UNUSED_VARS(t, xs0, stride);
NO_DEVICE_CODE;
#else
load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
}

template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if 1
// TODO: more generic handling
static_assert(sizeof(T) == 4, "bad type size");
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
#else
load_generic(t, xs0, stride);
#endif // 1
#else
tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
load_ldmatrix(t16[0], xs0 + 0*stride, stride);
load_ldmatrix(t16[1], xs0 + 16*stride, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}

template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -546,4 +645,43 @@ namespace ggml_cuda_mma {
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}

template <typename T1, typename T2, int J, int K>
static __device__ __forceinline__ void mma(
tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B);
}

static __device__ __forceinline__ void mma(
tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
#else
tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
}
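To show how the pieces added in this file fit together, here is a rough device-side sketch (illustrative only; `volta_mmf_sketch`, `A_ptr`, `B_ptr`, `D_ptr`, and `stride` are made-up placeholders, and the real call sites are the CUDA matrix-multiplication kernels such as mmf.cu below, not this function):

using namespace ggml_cuda_mma;

// Sketch: one 32x8 output tile accumulated from a 32x8 A fragment and an 8x8 B fragment.
static __device__ void volta_mmf_sketch(const half2 * A_ptr, const half2 * B_ptr,
                                        float * D_ptr, const int stride) {
    tile<32, 8, half2> A;  // 8 half2 per thread on Volta
    tile< 8, 8, half2> B;  // B fragment, mirrored layout on Volta
    tile<32, 8, float> D;  // accumulator, zero-initialized by the tile struct

    load_ldmatrix(A, A_ptr, stride);  // the new tile<32, 8, T> overload above
    load_generic (B, B_ptr, stride);  // generic per-element load for the B fragment
    mma(D, A, B);                     // lowers to 4x mma.sync.aligned.m8n8k4 on Volta

    // Store the accumulator using the tile's own index helpers (permuted rows by default).
#pragma unroll
    for (int l = 0; l < D.ne; ++l) {
        D_ptr[D.get_i(l)*stride + D.get_j(l)] = D.x[l];
    }
}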
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmf.cu
@@ -148,7 +148,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
case GGML_TYPE_F32:
return ampere_mma_available(cc);
case GGML_TYPE_F16:
return turing_mma_available(cc);
return volta_mma_available(cc) || turing_mma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc);
default: