
Commit 1d99b61

[NVIDIA] Expose cublas.gemm (#7656)
Useful for testing the performance of a GEMM implementation.
1 parent 5a87bde commit 1d99b61

3 files changed (+111 / -54 lines)


python/tutorials/09-persistent-matmul.py

Lines changed: 7 additions & 6 deletions
@@ -76,12 +76,13 @@ def _matmul_launch_metadata(grid, kernel, args):
 
 def matmul_get_configs(pre_hook=None):
     return [
-        triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K" : BK, "GROUP_SIZE_M" : 8}, num_stages=s, num_warps=w, pre_hook=pre_hook) \
-        for BM in [128] \
-        for BN in [128, 256] \
-        for BK in [64,128] \
-        for s in ([3,4]) \
-        for w in [4,8] \
+        triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K": BK, "GROUP_SIZE_M": 8}, num_stages=s,
+                      num_warps=w, pre_hook=pre_hook)
+        for BM in [128]
+        for BN in [128, 256]
+        for BK in [64, 128]
+        for s in ([2, 3, 4])
+        for w in [4, 8]
     ]
 
 
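The wider num_stages sweep slightly enlarges the autotuning space. A quick count of the configs generated by the comprehension above (a sketch; the lists simply mirror the new code):

# Count the triton.Config objects produced by matmul_get_configs() after this change.
block_m = [128]
block_n = [128, 256]
block_k = [64, 128]
num_stages = [2, 3, 4]   # previously [3, 4]
num_warps = [4, 8]

n_configs = len(block_m) * len(block_n) * len(block_k) * len(num_stages) * len(num_warps)
print(n_configs)  # 24 configs, up from 16 before this change
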
third_party/nvidia/include/cublas_instance.h

Lines changed: 9 additions & 6 deletions
@@ -120,8 +120,8 @@ class CublasLtInstance {
   }
 
   // Simple wrapper around the cublasLtMatmul function
-  void matmul_impl(int m, int n, int k, uint64_t A, uint64_t B, uint64_t D,
-                   cudaDataType_t dtype) {
+  void gemm_impl(int m, int n, int k, uint64_t A, uint64_t B, uint64_t C,
+                 uint64_t D, cudaDataType_t dtype, float alpha, float beta) {
     cublasLtMatmulDesc_t matmulDesc = NULL;
 
     cublasOperation_t transa = CUBLAS_OP_T;
@@ -160,10 +160,8 @@ class CublasLtInstance {
           "No valid algorithm found by cublasLtMatmulAlgoGetHeuristic");
     }
 
-    float alpha = 1.0f;
-    float beta = 0.0f;
     successOrExit(cublasLtMatmul(ltHandle, matmulDesc, &alpha, (void *)A, Adesc,
-                                 (void *)B, Bdesc, &beta, nullptr, Cdesc,
+                                 (void *)B, Bdesc, &beta, (void *)C, Cdesc,
                                  (void *)D, Ddesc, &heuristicResult.algo,
                                  (void *)workspace, workspaceSize, 0));
     if (Ddesc)
@@ -206,7 +204,12 @@ class CublasLtInstance {
              cudaDataType_t dtype) {
     // CUDA is column-major, while triton is row-major, therefore we need to
     // reverse the order of the matrices ( A * B = (B^T * A^T)^T ).
-    matmul_impl(n, m, k, B, A, C, dtype);
+    gemm_impl(n, m, k, B, A, 0, C, dtype, 1.0f, 0.0f);
+  }
+
+  void gemm(int m, int n, int k, uint64_t A, uint64_t B, uint64_t C, uint64_t D,
+            cudaDataType_t dtype, float alpha, float beta) {
+    gemm_impl(n, m, k, B, A, C, D, dtype, alpha, beta);
   }
 };
 
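The argument swap in matmul and gemm follows from the comment above: cuBLAS is column-major while Triton tensors are row-major, so cuBLAS reads a row-major buffer as its transpose, and the product can be formed as A * B = (B^T * A^T)^T with m and n exchanged. A minimal PyTorch sketch of that identity, extended with the new alpha/beta epilogue (the shapes below are arbitrary examples):

# Verify the identity behind the m/n and A/B swap: computing the product with
# swapped, transposed operands and transposing the result matches the direct form.
import torch

M, N, K = 64, 32, 16
A = torch.randn(M, K)
B = torch.randn(K, N)
C = torch.randn(M, N)
alpha, beta = 2.0, 0.5

lhs = alpha * (A @ B) + beta * C            # row-major view of the problem
rhs = (alpha * (B.T @ A.T) + beta * C.T).T  # swapped-operand, transposed view
assert torch.allclose(lhs, rhs, atol=1e-4)
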
third_party/nvidia/triton_nvidia.cc

Lines changed: 95 additions & 42 deletions
@@ -84,6 +84,53 @@ void init_triton_hopper_passes(py::module &&m) {
       mlir::createNVGPUWarpSpecialization, int, bool);
 }
 
+static void checkMatmulConstraints(const std::string &A_dtype,
+                                   const std::string &B_dtype,
+                                   const std::string &C_dtype,
+                                   const std::vector<int> &A_shape,
+                                   const std::vector<int> &B_shape,
+                                   const std::vector<int> &C_shape) {
+  if (A_dtype != B_dtype || A_dtype != C_dtype) {
+    throw std::runtime_error("Data types do not match.");
+  }
+  if (A_dtype != "torch.float8_e4m3fn" && A_dtype != "torch.float16") {
+    throw std::runtime_error("Unsupported data type.");
+  }
+
+  if (A_shape.size() != 2 || B_shape.size() != 2 || C_shape.size() != 2) {
+    throw std::runtime_error("Only 2D matrices are supported.");
+  }
+
+  int k = A_shape[1];
+  if (k != B_shape[1]) {
+    throw std::runtime_error(
+        "Matrix dimensions do not match. A is [" + std::to_string(A_shape[0]) +
+        ", " + std::to_string(A_shape[1]) + "], B is [" +
+        std::to_string(B_shape[0]) + ", " + std::to_string(B_shape[1]) +
+        "]. Expected A.shape[1] == B.shape[1]. Note "
+        "that B needs to be transposed.");
+  }
+
+  int m = A_shape[0];
+  if (m != C_shape[0]) {
+    throw std::runtime_error(
+        "Matrix dimensions do not match. A is [" + std::to_string(A_shape[0]) +
+        ", " + std::to_string(A_shape[1]) + "], C is [" +
+        std::to_string(C_shape[0]) + ", " + std::to_string(C_shape[1]) +
+        "]. Expected A.shape[0] == C.shape[0].");
+  }
+
+  int n = B_shape[0];
+  if (n != C_shape[1]) {
+    throw std::runtime_error(
+        "Matrix dimensions do not match. B is [" + std::to_string(B_shape[0]) +
+        ", " + std::to_string(B_shape[1]) + "], C is [" +
+        std::to_string(C_shape[0]) + ", " + std::to_string(C_shape[1]) +
+        "]. Expected B.shape[0] == C.shape[1]. Note "
+        "that B needs to be transposed.");
+  }
+}
+
 void init_triton_nvidia(py::module &&m) {
   auto passes = m.def_submodule("passes");
   init_triton_nvidia_passes_nvws(passes.def_submodule("nvws"));
@@ -155,22 +202,64 @@ void init_triton_nvidia(py::module &&m) {
                 workspace.attr("element_size")().cast<size_t>();
             return new CublasLtInstance(wrk_ptr, wrk_size);
           }))
-      .def("matmul", [](CublasLtInstance &self, py::object &A, py::object &B,
-                        py::object &C) {
+      .def("matmul",
+           [](CublasLtInstance &self, py::object &A, py::object &B,
+              py::object &C) {
+             auto A_ptr = A.attr("data_ptr")().cast<uint64_t>();
+             auto B_ptr = B.attr("data_ptr")().cast<uint64_t>();
+             auto C_ptr = C.attr("data_ptr")().cast<uint64_t>();
+
+             auto A_shape = A.attr("shape").cast<std::vector<int>>();
+             auto B_shape = B.attr("shape").cast<std::vector<int>>();
+             auto C_shape = C.attr("shape").cast<std::vector<int>>();
+
+             auto A_dtype =
+                 A.attr("dtype").attr("__str__")().cast<std::string>();
+             auto B_dtype =
+                 B.attr("dtype").attr("__str__")().cast<std::string>();
+             auto C_dtype =
+                 C.attr("dtype").attr("__str__")().cast<std::string>();
+
+             checkMatmulConstraints(A_dtype, B_dtype, C_dtype, A_shape, B_shape,
+                                    C_shape);
+
+             std::string dtype_str =
+                 A_dtype.substr(A_dtype.find_last_of('.') + 1);
+             cudaDataType_t dtype;
+             if (dtype_str == "float8_e4m3fn") {
+               dtype = CUDA_R_8F_E4M3;
+             } else if (dtype_str == "float16") {
+               dtype = CUDA_R_16F;
+             }
+
+             self.matmul(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr,
+                         C_ptr, dtype);
+           })
+      .def("gemm", [](CublasLtInstance &self, py::object &A, py::object &B,
+                      py::object &C, py::object &D, float alpha, float beta) {
         auto A_ptr = A.attr("data_ptr")().cast<uint64_t>();
         auto B_ptr = B.attr("data_ptr")().cast<uint64_t>();
         auto C_ptr = C.attr("data_ptr")().cast<uint64_t>();
+        auto D_ptr = D.attr("data_ptr")().cast<uint64_t>();
 
         auto A_shape = A.attr("shape").cast<std::vector<int>>();
         auto B_shape = B.attr("shape").cast<std::vector<int>>();
         auto C_shape = C.attr("shape").cast<std::vector<int>>();
+        auto D_shape = D.attr("shape").cast<std::vector<int>>();
 
         auto A_dtype = A.attr("dtype").attr("__str__")().cast<std::string>();
         auto B_dtype = B.attr("dtype").attr("__str__")().cast<std::string>();
         auto C_dtype = C.attr("dtype").attr("__str__")().cast<std::string>();
+        auto D_dtype = D.attr("dtype").attr("__str__")().cast<std::string>();
 
-        assert(A_dtype == B_dtype && A_dtype == C_dtype);
-        assert(A_dtype == "torch.float8_e4m3fn" || A_dtype == "torch.float16");
+        checkMatmulConstraints(A_dtype, B_dtype, D_dtype, A_shape, B_shape,
+                               D_shape);
+        if (C_dtype != "torch.float16") {
+          throw std::runtime_error("C dtype must be float16, got " + C_dtype);
+        }
+        if (C_shape != D_shape) {
+          throw std::runtime_error("C and D shapes must match");
+        }
 
         std::string dtype_str = A_dtype.substr(A_dtype.find_last_of('.') + 1);
         cudaDataType_t dtype;
@@ -180,43 +269,7 @@ void init_triton_nvidia(py::module &&m) {
          dtype = CUDA_R_16F;
        }
 
-        if (A_shape.size() != 2 || B_shape.size() != 2 || C_shape.size() != 2) {
-          throw std::runtime_error("Only 2D matrices are supported.");
-        }
-
-        int k = A_shape[1];
-        if (k != B_shape[1]) {
-          throw std::runtime_error("Matrix dimensions do not match. A is [" +
-                                   std::to_string(A_shape[0]) + ", " +
-                                   std::to_string(A_shape[1]) + "], B is [" +
-                                   std::to_string(B_shape[0]) + ", " +
-                                   std::to_string(B_shape[1]) +
-                                   "]. Expected A.shape[1] == B.shape[1]. Note "
-                                   "that B needs to be transposed.");
-        }
-
-        int m = A_shape[0];
-        if (m != C_shape[0]) {
-          throw std::runtime_error("Matrix dimensions do not match. A is [" +
-                                   std::to_string(A_shape[0]) + ", " +
-                                   std::to_string(A_shape[1]) + "], C is [" +
-                                   std::to_string(C_shape[0]) + ", " +
-                                   std::to_string(C_shape[1]) +
-                                   "]. Expected A.shape[0] == C.shape[0].");
-        }
-
-        int n = B_shape[0];
-        if (n != C_shape[1]) {
-          throw std::runtime_error("Matrix dimensions do not match. B is [" +
-                                   std::to_string(B_shape[0]) + ", " +
-                                   std::to_string(B_shape[1]) + "], C is [" +
-                                   std::to_string(C_shape[0]) + ", " +
-                                   std::to_string(C_shape[1]) +
-                                   "]. Expected B.shape[0] == C.shape[1]. Note "
-                                   "that B needs to be transposed.");
-        }
-
-        self.matmul(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr, C_ptr,
-                    dtype);
+        self.gemm(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr, C_ptr,
+                  D_ptr, dtype, alpha, beta);
       });
 }
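For reference, a usage sketch of the newly exposed gemm binding. Assumptions not shown in this diff: the class is reachable as nvidia.cublas.CublasLt (as in python/tutorials/09-persistent-matmul.py), a 32 MiB workspace is sufficient, and a CUDA device with cuBLASLt support is available. Per the checks above, A is [M, K], B is [N, K] (already transposed), C and D are [M, N], A/B/D share one dtype (float16 or float8_e4m3fn), C must be float16, and the call is intended to compute D = alpha * A @ B^T + beta * C.

# Sketch only: exercises cublas.gemm with float16 inputs and compares against
# a torch reference. Import path and workspace size follow the existing tutorial.
import torch
from triton._C.libtriton import nvidia

M, N, K = 512, 256, 128
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)  # B is [N, K]: pre-transposed
c = torch.randn(M, N, device="cuda", dtype=torch.float16)
d = torch.empty(M, N, device="cuda", dtype=torch.float16)

workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
cublas = nvidia.cublas.CublasLt(workspace)

alpha, beta = 1.0, 1.0
cublas.gemm(a, b, c, d, alpha, beta)  # expected: d = alpha * a @ b.T + beta * c

ref = alpha * (a @ b.T) + beta * c
print(torch.allclose(d, ref, atol=1e-2, rtol=1e-2))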
