From c092bdc14be7f93beb03f1514a693b809d920933 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Tue, 25 Nov 2025 13:50:28 -0800 Subject: [PATCH] Cherry-picked commit with merge conflict --- aten/src/ATen/native/cuda/Blas.cpp | 59 +++ aten/src/ATen/native/hip/ck_group_gemm.h | 19 + aten/src/ATen/native/hip/ck_group_gemm.hip | 462 ++++++++++++++++++++ test/test_matmul_cuda.py | 470 +++++++++++++++++++++ 4 files changed, 1010 insertions(+) create mode 100644 aten/src/ATen/native/hip/ck_group_gemm.h create mode 100644 aten/src/ATen/native/hip/ck_group_gemm.hip diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 3bfa1859155ff..947c3fc98d057 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -18,6 +18,18 @@ #include #include #include +<<<<<<< HEAD +======= +#include +#ifdef USE_ROCM +#include +#endif +#include + +#ifdef USE_FBGEMM_GENAI +#include +#endif +>>>>>>> e9d7164566 ([release/2.9] Port group_gemm commits from upstream PT (#2829)) #ifndef AT_PER_OPERATOR_HEADERS #include @@ -946,9 +958,12 @@ static bool _scaled_mm_allowed_device() { return true; } } +<<<<<<< HEAD return false; #else return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); +======= +>>>>>>> e9d7164566 ([release/2.9] Port group_gemm commits from upstream PT (#2829)) #endif } @@ -1557,9 +1572,53 @@ bool use_fast_accum) { +<<<<<<< HEAD #else TORCH_CHECK(false, "grouped gemm is not supported on ROCM") #endif +======= +Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b, +const std::optional& offs, +const std::optional& bias, +std::optional out_dtype) { + _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype); + bool a_b_and_out_are_bf16 = ( + mat_a.dtype() == at::kBFloat16 && + mat_b.dtype() == at::kBFloat16 && + out_dtype.value_or(at::kBFloat16) == at::kBFloat16 + ); +#ifndef USE_ROCM + bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16; +#else + // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used. 
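+  // On the ROCm side the fast path is instead gated below by USE_ROCM_CK_GEMM and the gfx942/gfx950 architecture check.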
+ // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm + bool use_fast_path = false; + // On non CK system(w/ ROCm), make sure use_fast_path is false +#if defined(USE_ROCM_CK_GEMM) + if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) { + use_fast_path = true; + } +#endif //USE_ROCM_CK_GEMM +#endif + const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype); + Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); + if (use_fast_path) { + // fast path, no d2h sync needed +#ifndef USE_ROCM + at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); +#else +#if defined(USE_ROCM_CK_GEMM) + at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out); +#else + TORCH_WARN("ROCm: Group Gemm through CK not selected."); +#endif //USE_ROCM_CK_GEMM +#endif + } else { + _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out); + } + return out; +} +>>>>>>> e9d7164566 ([release/2.9] Port group_gemm commits from upstream PT (#2829)) } diff --git a/aten/src/ATen/native/hip/ck_group_gemm.h b/aten/src/ATen/native/hip/ck_group_gemm.h new file mode 100644 index 0000000000000..c50307c9f8ea3 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_group_gemm.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace hip { +namespace detail { +void group_gemm_ck( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& offs, + const std::optional& bias, + at::Tensor& out); + +} // namespace detail +} // namespace hip +} // namespace at diff --git a/aten/src/ATen/native/hip/ck_group_gemm.hip b/aten/src/ATen/native/hip/ck_group_gemm.hip new file mode 100644 index 0000000000000..c436ad660c1c7 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_group_gemm.hip @@ -0,0 +1,462 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +template +using S = ck::Sequence; + +namespace at { +namespace hip { +namespace detail { + +namespace CkTypes { + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + using F32 = float; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; +} + +template +using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< + ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, + DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType, + CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough, + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, + S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>, + 3, 8, 8, 1, + S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>, + 3, 8, 8, 1, + 1, 1, + S<1,32,1,8>, 4 +>; + +template +void launch_grouped_bgemm_ck_impl_dispatch( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& offs, + at::Tensor& out) +{ + using DeviceOp = GroupedGemmKernel; + using PassThrough = CkTypes::PassThrough; + + std::vector gemm_descs; + std::vector p_a_ptrs, p_b_ptrs; + std::vector p_e_ptrs; + // Note: d_ptrs will be resized after we populate the other vectors + + const int mat_a_dim = mat_a.dim(); + const int mat_b_dim = mat_b.dim(); + + const char* a_ptr_base = reinterpret_cast(mat_a.data_ptr()); + const char* b_ptr_base = reinterpret_cast(mat_b.data_ptr()); + char* out_ptr_base = reinterpret_cast(out.data_ptr()); + const size_t a_element_size = 
mat_a.element_size(); + const size_t b_element_size = mat_b.element_size(); + const size_t out_element_size = out.element_size(); + + // for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses. + if (mat_a_dim == 2 && mat_b_dim == 2) { + // 2D*2D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + const int M = mat_a.size(0); // number of rows in A + const int N = mat_b.size(1); // number of columns in B + const int K = mat_a.size(1); // columns in A == rows in B + // for 2d*2d input, output is 3d. + // for each group, A columns (K) are sliced. M and N dimensions are not sliced. + for (int i = 0; i < num_groups; ++i) { + int start_k = (i == 0) ? 0 : offs_accessor[i-1]; + int end_k = offs_accessor[i]; + int k = end_k - start_k; + + //K dimension are sliced, hence select stride(1) always. + //K dimension is always dimension 1, regardless of memory layout (row/column major) + const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size; + const void* group_b_ptr; + int ldb; + + if (std::is_same::value) { + // Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset + group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size; + // Leading dimension = distance between rows = stride(0) + ldb = mat_b.stride(0); + } else { + // Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset + group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size; + // Leading dimension = distance between columns = stride(1) + ldb = mat_b.stride(1); + } + + // Calculate output pointer for group i in 3D tensor [num_groups, M, N] + // stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i + void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size; + int lda, ldc; + if (std::is_same::value) { + // Row-major A [M,K]: leading dimension = distance between rows = stride(0) + lda = mat_a.stride(0); + } else { + // Column-major A [M,K]: leading dimension = distance between columns = stride(1) + lda = mat_a.stride(1); + } + // Output is always row-major in 3D tensor [num_groups, M, N] + // Leading dimension for each group's [M,N] slice = stride(1) = N + ldc = out.stride(1); + size_t output_group_bytes = M * N * out_element_size; + void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes; + + gemm_descs.push_back({ + static_cast(M), + static_cast(N), + static_cast(k), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 2 && mat_b_dim == 3) { + // 2D*3D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + + // 2d*3d input, output is 2d. + // A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n] + // Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B + const int K = mat_a.size(1); // columns in A + // For 2D-3D case: The output determines N (result width) + const int N = out.size(1); // N is the width of the output tensor + + for (int i = 0; i < num_groups; ++i) { + int start_m = (i == 0) ? 
0 : offs_accessor[i - 1]; + int end_m = offs_accessor[i]; + int m = end_m - start_m; + + // Skip zero-sized groups but continue processing subsequent groups + if (m <= 0) { + continue; + } + + // Select A rows for group i: skip start_m rows + const void* group_a_ptr; + int lda; + if (std::is_same::value) { + // Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart + group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size; + lda = mat_a.stride(0); // distance between rows + } else { + // Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows) + group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size; + + // Detect stride pattern for A tensor to determine appropriate lda calculation + bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0)); + + if (a_is_strided_tensor) { + // For strided A tensors: stride(0) gives the actual leading dimension + lda = mat_a.stride(0); + } else { + // For non-strided A tensors: use the M dimension (total rows) + lda = mat_a.size(0); // Total M dimension for column-major layout + } + } + + // Select B batch for group i: B[i, :, :] + const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size; + int ldb; + + if (std::is_same::value) { + // Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed + ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N] + } else { + // Detect stride pattern to determine appropriate ldb calculation + bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2)); + + if (is_strided_tensor) { + // For strided tensors: stride(2) gives the actual leading dimension + ldb = mat_b.stride(2); + } else { + // For non-strided tensors: use the N dimension + ldb = mat_b.size(1); + } + } + + // Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N] + void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size; + int ldc = out.stride(0); // distance between rows in output (should be N for 2D case) + + gemm_descs.push_back({ + static_cast(m), + static_cast(N), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 3 && mat_b_dim == 3) { + // 3d*3d input, output is 3d - batched matrix multiplication + // A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n] + // Each batch is processed as a separate GEMM operation + const int batch_size = mat_a.size(0); + const int M = mat_a.size(1); // rows in each A matrix + const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed) + + // Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout + int N; + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + N = mat_b.size(2); + } else if (mat_b.size(2) == K) { + // B is [batch, n, k] - transposed layout + N = mat_b.size(1); + } else { + TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. 
A=[", + batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]"); + } + + for (int i = 0; i < batch_size; ++i) { + // Select A batch for group i: A[i, :, :] + const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size; + + // Select B batch for group i: B[i, :, :] + const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size; + + // Select output batch for group i: Output[i, :, :] + void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size; + + int lda, ldb, ldc; + + if (std::is_same::value) { + // Row-major A: leading dimension = distance between rows = stride(1) + lda = mat_a.stride(1); + } else { + // Column-major A: leading dimension = distance between columns = stride(2) + lda = mat_a.stride(2); + } + + if (std::is_same::value) { + // Row-major B: leading dimension = distance between rows + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + ldb = mat_b.stride(1); // stride between K rows + } else { + // B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM + ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n]) + } + } else { + // Column-major B: leading dimension = distance between columns + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + ldb = mat_b.stride(2); // stride between N columns + } else { + // B is [batch, n, k] - transposed layout + ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n]) + } + } + + // Output is typically row-major: leading dimension = distance between rows = stride(1) + ldc = out.stride(1); + + gemm_descs.push_back({ + static_cast(M), + static_cast(N), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 3 && mat_b_dim == 2) { + // 3D*2D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + // 3d*2d input, output is 3d. + // A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both) + // Offset divides N dimension of B, each group gets different slice of B and different batch of A + const int batch_size = mat_a.size(0); // n_groups + const int M = mat_a.size(1); // rows in each A matrix + const int K = mat_a.size(2); // columns in A + + // For row-major A and B case: B should be [K, total_N] + const int total_N = mat_b.size(1); // B is [K, total_N] for row-major + + for (int i = 0; i < num_groups; ++i) { + int start_n = (i == 0) ? 
0 : offs_accessor[i - 1]; + int end_n = offs_accessor[i]; + int n = end_n - start_n; + + // Skip zero-sized groups but continue processing subsequent groups + if (n <= 0) { + continue; + } + + // Select A batch for group i: A[i, :, :] + const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size; + + // Select B slice for group i: B[:, start_n:end_n] (B[K, total_N]) + const void* group_b_ptr; + int ldb; + + // Check if B is row-major or column-major + if (std::is_same::value) { + // Row-major B [K, total_N]: slice columns [start_n:end_n] + group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size; + ldb = mat_b.stride(0); // distance between rows (should be total_N) + } else { + // Column-major B [K, total_N]: slice columns [start_n:end_n] + group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size; + ldb = mat_b.stride(1); // distance between columns (should be K) + } + + // Select output slice for group i: Output[:, start_n:end_n] + void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size; + + int lda, ldc; + + // Row-major A: leading dimension = distance between rows = stride(1) + lda = mat_a.stride(1); + // Output is row-major: leading dimension = distance between rows = stride(0) + ldc = out.stride(0); + + gemm_descs.push_back({ + static_cast(M), + static_cast(n), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else { + TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim); + } + + TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups"); + + // Initialize d_ptrs with the correct size + std::vector> d_ptrs(p_a_ptrs.size()); + + static DeviceOp gemm_instance; + auto argument = gemm_instance.MakeArgument( + p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs, + gemm_descs, PassThrough{}, PassThrough{}, PassThrough{} + ); + TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument), + "CK Group GEMM: argument unsupported (shape/strides/type config)"); + size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument); + size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument); + + void* gemm_arg_buf = nullptr; + void* ws_buf = nullptr; + + hipMalloc(&gemm_arg_buf, arg_buf_size); + hipMalloc(&ws_buf, ws_size); + + gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf); + gemm_instance.SetWorkSpacePointer(&argument, ws_buf); + + auto invoker = gemm_instance.MakeInvoker(); + hipStream_t stream = c10::hip::getCurrentHIPStream(); + invoker.Run(argument, {stream}); + hipFree(gemm_arg_buf); + hipFree(ws_buf); +} + +void group_gemm_ck( + const at::Tensor& input_a, + const at::Tensor& input_b_colmajor, + const std::optional& offs, + const std::optional& /*bias*/, + at::Tensor& out) +{ + // Detect if input_a is row-major based on stride pattern + bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1); + bool b_col_major = (input_b_colmajor.dim() == 3) ? 
(input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1); + // Ensure tensor A is row-major and contiguous if not already + at::Tensor mat_a = input_a; + if (!a_row_major) { + // If A is not row-major, make it contiguous (row-major) + mat_a = input_a.contiguous(); + } + // Force tensor B to be column-major using double transpose trick + // This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape + at::Tensor mat_b = input_b_colmajor; + if (!b_col_major) { + mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1); + } + + // For 3D tensors, check the last dimension stride for row-major detection + a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1); + bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1); + + if (mat_a.dtype() == at::kBFloat16) { + // bf16 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else if (mat_a.dtype() == at::kHalf) { + // fp16 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else if (mat_a.dtype() == at::kFloat) { + // fp32 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else { + TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype"); + } + +} + +} // namespace detail +} // namespace hip +} // namespace at diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 9aab5e350a93a..bc9454e1a6cc1 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -254,6 +254,476 @@ def _expand_to_batch(t: torch.Tensor): # cross comparison self.assertEqual(out1_gpu, out2_gpu[0]) +<<<<<<< HEAD +======= + def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist): + for a, b, gO, agrad, bgrad, out in zip(alist, blist, gOlist, agradlist, bgradlist, outlist): + a = a.clone().detach().requires_grad_() + b = b.clone().detach().requires_grad_() + out_ref = torch.mm(a, b.t()) + out_ref.backward(gO) + self.assertEqual(out, out_ref) + if agrad is not None: + self.assertEqual(agrad, a.grad) + self.assertEqual(bgrad, b.grad) + + @xfailIfSM120OrLater + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @dtypes(torch.bfloat16, torch.float32, torch.float16) + def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype): + device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 + if a_row_major: 
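+ # When strided=True the buffer is over-allocated by k columns and then sliced, so `a` (and `b` below)
+ # become non-contiguous strided views; with strided=False the slice spans the whole tensor and stays contiguous.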
+ a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] + else: + a = torch.randn(k * n_groups + k * int(strided), m, device=device, dtype=dtype).t()[:, :k * n_groups] + + if b_row_major: + b = torch.randn(n, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] + else: + b = torch.randn(k * n_groups + k * int(strided), n, device=device, dtype=dtype).t()[:, :k * n_groups] + + a.requires_grad_(True) + b.requires_grad_(True) + offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) + + f = torch._grouped_mm + out = f(a, b.t(), offs=offs, out_dtype=dtype) + gO = torch.rand_like(out) + out.backward(gO) + offs_cpu = offs.cpu() + alist, blist, agradlist, bgradlist = [], [], [], [] + start = 0 + for i in range(n_groups): + alist.append(a[:, start:offs_cpu[i]]) + blist.append(b[:, start:offs_cpu[i]]) + agradlist.append(a.grad[:, start:offs_cpu[i]]) + bgradlist.append(b.grad[:, start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out) + + @xfailIfSM120OrLater + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @dtypes(torch.bfloat16, torch.float32, torch.float16) + def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype): + device = "cuda" + s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 + if a_row_major: + a = torch.randn(m * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k] + else: + a = torch.randn(k, (m + 2 * s_int) * n_groups, device=device, dtype=dtype).t()[:m * n_groups, :] + + if b_row_major: + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + b = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), n, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.t() + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + for check_zero_size in (False, True): + if check_zero_size and n_groups <= 1: + continue + + a.grad = None + b.grad = None + offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32) + if check_zero_size: + offs[0] = offs[1] + + f = torch._grouped_mm + out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) + gO = torch.rand_like(out) + if not check_zero_size: + out.backward(gO) + offs_cpu = offs.cpu() + alist, agradlist, gOlist, outlist = [], [], [], [] + bgradlist = [None] * n_groups if check_zero_size else b.grad + start = 0 + for i in range(n_groups): + alist.append(a[start:offs_cpu[i]]) + agradlist.append(None if check_zero_size else a.grad[start:offs_cpu[i]]) + outlist.append(out[start:offs_cpu[i]]) + gOlist.append(gO[start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist) + + + @xfailIfSM120OrLater + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @dtypes(torch.bfloat16, torch.float32, torch.float16) + def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): + device = 
"cuda" + s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 + if a_row_major: + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + a = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), m, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + if b_row_major: + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + b = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), n, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.transpose(-2, -1) + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + + f = torch._grouped_mm + out = f(a, b.transpose(-2, -1), out_dtype=dtype) + gO = torch.rand_like(out) + out.backward(gO) + self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out) + + @xfailIfSM120OrLater + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @dtypes(torch.bfloat16, torch.float32, torch.float16) + def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): + device = "cuda" + s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 + if a_row_major: + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + a = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), m, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + if b_row_major: + b = torch.randn(n * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k] + else: + b = torch.randn(k, n * (n_groups + s_int), device=device, dtype=dtype).transpose(-2, -1)[:n * n_groups, :] + + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.transpose(-2, -1) + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + for check_zero_size in (False, True): + if check_zero_size and n_groups <= 1: + continue + + offs = torch.arange(n, n_groups * n + 1, n, device=device, dtype=torch.int32) + if check_zero_size: + offs[0] = offs[1] + + f = torch._grouped_mm + out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) + gO = torch.rand_like(out) + if not check_zero_size: + out.backward(gO) + offs_cpu = offs.cpu() + blist, outlist, bgradlist, gOlist = [], [], [], [] + agradlist = [None] * n_groups if check_zero_size else a.grad + start = 0 + for i in range(n_groups): + blist.append(b[start:offs_cpu[i]]) + bgradlist.append(b.grad[start:offs_cpu[i]]) + outlist.append(out[:, start:offs_cpu[i]]) + gOlist.append(gO[:, start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @xfailIfSM100OrLater + # TODO(future PR): enable compile for torch._grouped_mm fallback path + @unittest.skipIf(not SM90OrLater, "Grouped gemm with compile supported on SM90") + @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @parametrize("max_autotune", [False, True]) 
+ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune): + torch._dynamo.reset() + + device = "cuda" + dtype_AB = torch.bfloat16 + dtype_offset = torch.int32 + + align = 16 // dtype_AB.itemsize + + f_ref = torch._grouped_mm + + options = {} + if max_autotune: + options.update( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + } + ) + f = torch.compile( + f_ref, + options=options, + ) + + if op == "2d/2d": + m, n = 3, 7 + m_align = (m + align - 1) // align * align + n_align = (n + align - 1) // align * align + if not a_row_major and not b_row_major: + offs = torch.tensor([0, 1, 6, 6, 7], device=device, dtype=dtype_offset) + else: + offs = torch.tensor([0, 8, 16, 16, 27], device=device, dtype=dtype_offset) + ngroups = offs.shape[0] + k = offs[-1] + k_align = (k + align - 1) // align * align + + if a_row_major: + A = torch.randn(m, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + A = torch.randn(k, m_align, device=device, dtype=dtype_AB).t()[:m, :] + if b_row_major: + B = torch.randn(n, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :] + elif op == "2d/3d": + n, k = 7, 259 # k is larger here, to validate iterating over k tiles on an op + n_align = (n + align - 1) // align * align + k_align = (k + align - 1) // align * align + if a_row_major: + offs = torch.tensor([0, 1, 3, 3, 5], device=device, dtype=dtype_offset) + else: + offs = torch.tensor([0, 8, 16, 16, 19], device=device, dtype=dtype_offset) + ngroups = offs.shape[0] + m = offs[-1] + m_align = (m + align - 1) // align * align + + if a_row_major: + A = torch.randn(m, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + A = torch.randn(k, m_align, device=device, dtype=dtype_AB).t()[:m, :] + if b_row_major: + B = torch.randn(ngroups, n, k_align, device=device, dtype=dtype_AB)[:, :, :k] + else: + B = torch.randn(ngroups, k, n_align, device=device, dtype=dtype_AB).transpose( + -2, -1 + )[:, :n, :] + elif op == "3d/2d": + m, k = 3, 13 + m_align = (m + align - 1) // align * align + k_align = (k + align - 1) // align * align + offs = torch.tensor([0, 8, 16, 16, 19], device=device, dtype=dtype_offset) + ngroups = offs.shape[0] + n = offs[-1] + n_align = (n + align - 1) // align * align + + if a_row_major: + A = torch.randn(ngroups, m, k_align, device=device, dtype=dtype_AB)[:, :, :k] + else: + A = torch.randn(ngroups, k, m_align, device=device, dtype=dtype_AB).transpose( + -2, -1 + )[:, :m, :] + if b_row_major: + B = torch.randn(n, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :] + elif op == "3d/3d": + offs = None + ngroups = 5 + m, n, k = 3, 7, 13 + m_align = (m + align - 1) // align * align + n_align = (n + align - 1) // align * align + k_align = (k + align - 1) // align * align + if a_row_major: + A = torch.randn(ngroups, m, k_align, device=device, dtype=dtype_AB)[:, :, :k] + else: + A = torch.randn(ngroups, k, m_align, device=device, dtype=dtype_AB).transpose( + -2, -1 + )[:, :m, :] + if b_row_major: + B = torch.randn(ngroups, n, k_align, device=device, dtype=dtype_AB)[:, :, :k] + else: + B = torch.randn(ngroups, k, n_align, device=device, dtype=dtype_AB).transpose( + -2, -1 + )[:, :n, :] + else: + raise AssertionError(f"Invalid op: {op}") + + C_ref = f_ref(A, B.transpose(-2, -1), offs=offs) + C = f(A, B.transpose(-2, -1), offs=offs) + torch.testing.assert_close(C, C_ref) + + + @onlyCUDA + @parametrize("input_dtype", 
[torch.float32, torch.float16, torch.bfloat16]) + @parametrize("M", [1, 32, 64]) + @parametrize("N", [1, 32, 64]) + @parametrize("K", [1, 32, 64]) + @parametrize("batch_size", [None, 1, 16]) + # TODO: enable rocblas path on ROCm + @parametrize("backend", ["cublaslt"] if torch.version.hip else ["cublas", "cublaslt"]) + def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend): + device = "cuda" + dtype = input_dtype + with blas_library_context(backend): + def create_inputs(B=None): + if B is None: + a = torch.randn(M, K, device=device, dtype=dtype) + b = torch.randn(K, N, device=device, dtype=dtype) + else: + a = torch.randn(B, M, K, device=device, dtype=dtype) + b = torch.randn(B, K, N, device=device, dtype=dtype) + return a, b + + a, b = create_inputs(batch_size) + + a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32) + + output_dtypes = [torch.float32] + + if input_dtype != torch.float32: + output_dtypes.append(input_dtype) + + for output_dtype in output_dtypes: + # Catch edge case of incompat with bfloat16 and major version < 8 + if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16: + if output_dtype == torch.bfloat16: + continue + + if batch_size: + with self.assertRaises(RuntimeError): + torch.bmm(a, b, out_dtype=output_dtype) + else: + with self.assertRaises(RuntimeError): + torch.mm(a, b, out_dtype=output_dtype) + else: + if batch_size: + out = torch.bmm(a, b, out_dtype=output_dtype) + baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b) + else: + out = torch.mm(a, b, out_dtype=output_dtype) + baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b) + + self.assertEqual(out.dtype, output_dtype) + + torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3) + + + @onlyCUDA + @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @parametrize("M", [1, 32, 64]) + @parametrize("N", [1, 32, 64]) + @parametrize("K", [1, 32, 64]) + @parametrize("batch_size", [None, 1, 32]) + # TODO: enable rocblas path on ROCm + @parametrize("backend", ["cublaslt"] if torch.version.hip else ["cublas", "cublaslt"]) + def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend): + device = "cuda" + dtype = input_dtype + with blas_library_context(backend): + def create_inputs(B=None): + if B is None: + a = torch.randn(M, K, device=device, dtype=dtype) + b = torch.randn(K, N, device=device, dtype=dtype) + c = torch.randn(M, N, device=device, dtype=dtype) + else: + a = torch.randn(B, M, K, device=device, dtype=dtype) + b = torch.randn(B, K, N, device=device, dtype=dtype) + c = torch.randn(B, M, N, device=device, dtype=dtype) + + return a, b, c + + a, b, c = create_inputs(batch_size) + + a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32) + + output_dtypes = [torch.float32] + + if input_dtype != torch.float32: + output_dtypes.append(input_dtype) + + for output_dtype in output_dtypes: + # Catch edge case of incompat with bfloat16 and major version < 8 + if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16: + if output_dtype == torch.bfloat16: + continue + + if batch_size: + with self.assertRaises(RuntimeError): + torch.baddbmm(c, a, b, out_dtype=output_dtype) + else: + with self.assertRaises(RuntimeError): + torch.addmm(c, a, b, out_dtype=output_dtype) + else: + if batch_size: + out = torch.baddbmm(c, a, b, out_dtype=output_dtype) + if output_dtype == torch.float32: + baseline = torch.baddbmm(c_fp32, a_fp32, 
b_fp32) + else: + baseline = torch.baddbmm(c, a, b) + else: + out = torch.addmm(c, a, b, out_dtype=output_dtype) + if output_dtype == torch.float32: + baseline = torch.addmm(c_fp32, a_fp32, b_fp32) + else: + baseline = torch.addmm(c, a, b) + + self.assertEqual(out.dtype, output_dtype) + torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3) + + + @onlyCUDA + @skipIfRocm + @parametrize("batch_size", [1, 32]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend): + M, N, K = 32, 32, 32 + device = "cuda" + dtype = torch.float16 + with blas_library_context(backend): + torch.backends.cuda.preferred_blas_library(backend) + + orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation + torch.backends.cuda.matmul.allow_fp16_accumulation = True + + def create_inputs(): + a = torch.randn(M, K, device=device, dtype=dtype) + b = torch.randn(K, N, device=device, dtype=dtype) + c = torch.randn(M, N, device=device, dtype=dtype) + return a, b, c + + def expand(tensor): + return tensor.unsqueeze(0).expand(batch_size, *tensor.shape) + + a, b, c = create_inputs() + + with self.assertRaises(Exception): + torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32) + + with self.assertRaises(Exception): + torch.addmm(c, a, b, out_dtype=torch.float32) + + with self.assertRaises(Exception): + torch.bmm(expand(a,), expand(b), out_dtype=torch.float32) + + with self.assertRaises(Exception): + torch.mm(a, b, out_dtype=torch.float32) + + torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum +>>>>>>> e9d7164566 ([release/2.9] Port group_gemm commits from upstream PT (#2829)) f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"