diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 49356f8c79bc8..83f2b74f38969 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -20,6 +20,9 @@
 #include
 #include
 #include
+#ifdef USE_ROCM
+#include <ATen/native/hip/ck_group_gemm.h>
+#endif
 #include
 
 #ifdef USE_FBGEMM_GENAI
@@ -1083,16 +1086,6 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals
 #endif
 }
 
-static bool _grouped_mm_allowed_device() {
-#ifdef USE_ROCM
-  return false;
-#else
-  auto dprops = at::cuda::getCurrentDeviceProperties();
-  // CUDA capability 8.0 and greater
-  return dprops->major >= 8;
-#endif
-}
-
 #ifdef USE_ROCM
 static bool _scaled_mm_is_fnuz() {
   return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
@@ -1789,26 +1782,42 @@ Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
 const std::optional<at::Tensor>& offs,
 const std::optional<at::Tensor>& bias,
 std::optional<c10::ScalarType> out_dtype) {
-#ifndef USE_ROCM
   _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
   bool a_b_and_out_are_bf16 = (
     mat_a.dtype() == at::kBFloat16 &&
     mat_b.dtype() == at::kBFloat16 &&
     out_dtype.value_or(at::kBFloat16) == at::kBFloat16
   );
+#ifndef USE_ROCM
   bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
+#else
+  // Unlike the CUDA branch, don't gate on _scaled_mm_allowed_device here: _grouped_mm_cuda involves no scales.
+  // The _grouped_mm_fallback path is safe on any ROCm GPU since it only calls regular mm/bmm.
+  bool use_fast_path = false;
+  // On ROCm builds without CK, use_fast_path stays false.
+#if defined(USE_ROCM_CK_GEMM)
+  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
+    use_fast_path = true;
+  }
+#endif //USE_ROCM_CK_GEMM
+#endif
   const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
   if (use_fast_path) {
     // fast path, no d2h sync needed
+#ifndef USE_ROCM
     at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
+#else
+#if defined(USE_ROCM_CK_GEMM)
+    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
+#else
+    TORCH_WARN("ROCm: Group Gemm through CK not selected.");
+#endif //USE_ROCM_CK_GEMM
+#endif
   } else {
     _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
   }
   return out;
-#else
-  TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
-#endif
 }
 
 Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
diff --git a/aten/src/ATen/native/hip/ck_group_gemm.h b/aten/src/ATen/native/hip/ck_group_gemm.h
new file mode 100644
index 0000000000000..c50307c9f8ea3
--- /dev/null
+++ b/aten/src/ATen/native/hip/ck_group_gemm.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace at {
+namespace hip {
+namespace detail {
+void group_gemm_ck(
+    const at::Tensor& mat_a,
+    const at::Tensor& mat_b,
+    const std::optional<at::Tensor>& offs,
+    const std::optional<at::Tensor>& bias,
+    at::Tensor& out);
+
+} // namespace detail
+} // namespace hip
+} // namespace at
diff --git a/aten/src/ATen/native/hip/ck_group_gemm.hip b/aten/src/ATen/native/hip/ck_group_gemm.hip
new file mode 100644
index 0000000000000..c436ad660c1c7
--- /dev/null
+++ b/aten/src/ATen/native/hip/ck_group_gemm.hip
@@ -0,0 +1,462 @@
+#undef __HIP_NO_HALF_CONVERSIONS__
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+namespace at {
+namespace hip {
+namespace detail {
+
+namespace CkTypes {
+  using BF16 = ck::bhalf_t;
+  using F16 = ck::half_t;
+  using F32 = float;
+  using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+}
+
+template <typename ALayout, typename BLayout, typename DataType>
+using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
+    ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
+    DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
+    CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
+    ck::tensor_operation::device::GemmSpecialization::MNKPadding,
+    1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
+    S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
+    3, 8, 8, 1,
+    S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
+    3, 8, 8, 1,
+    1, 1,
+    S<1,32,1,8>, 4
+>;
+
+template <typename ALayout, typename BLayout, typename DataType>
+void launch_grouped_bgemm_ck_impl_dispatch(
+    const at::Tensor& mat_a,
+    const at::Tensor& mat_b,
+    const std::optional<at::Tensor>& offs,
+    at::Tensor& out)
+{
+  using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
+  using PassThrough = CkTypes::PassThrough;
+
+  std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
+  std::vector<const void*> p_a_ptrs, p_b_ptrs;
+  std::vector<void*> p_e_ptrs;
+  // Note: d_ptrs will be resized after we populate the other vectors.
+
+  const int mat_a_dim = mat_a.dim();
+  const int mat_b_dim = mat_b.dim();
+
+  const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
+  const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
+  char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
+  const size_t a_element_size = mat_a.element_size();
+  const size_t b_element_size = mat_b.element_size();
+  const size_t out_element_size = out.element_size();
+
+  // For each group, compute m, n, k, lda, ldb, ldc and the A, B, and output base pointers.
+  if (mat_a_dim == 2 && mat_b_dim == 2) {
+    // The 2D*2D case requires an offset tensor.
+    auto offs_accessor = offs->accessor<int, 1>();
+    int num_groups = offs_accessor.size(0);
+    const int M = mat_a.size(0); // number of rows in A
+    const int N = mat_b.size(1); // number of columns in B
+    const int K = mat_a.size(1); // columns in A == rows in B
+    // For 2D*2D inputs the output is 3D.
+    // Each group slices the K dimension of A; the M and N dimensions are not sliced.
+    for (int i = 0; i < num_groups; ++i) {
+      int start_k = (i == 0) ? 0 : offs_accessor[i-1];
+      int end_k = offs_accessor[i];
+      int k = end_k - start_k;
+
+      // The K dimension is sliced, hence always select stride(1):
+      // K is always dimension 1 of A, regardless of memory layout (row- or column-major).
+      const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
+      const void* group_b_ptr;
+      int ldb;
+
+      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
+        // Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for the K offset.
+        group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
+        // Leading dimension = distance between rows = stride(0).
+        ldb = mat_b.stride(0);
+      } else {
+        // Column-major B [K,N]: K values are vertically adjacent, use stride(0) for the K offset.
+        group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
+        // Leading dimension = distance between columns = stride(1).
+        ldb = mat_b.stride(1);
+      }
+
+      // Calculate the output pointer for group i in the 3D tensor [num_groups, M, N].
+      // stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i.
+      void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
+      int lda, ldc;
+      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
+        // Row-major A [M,K]: leading dimension = distance between rows = stride(0).
+        lda = mat_a.stride(0);
+      } else {
+        // Column-major A [M,K]: leading dimension = distance between columns = stride(1).
+        lda = mat_a.stride(1);
+      }
+      // The output is always row-major in the 3D tensor [num_groups, M, N].
+      // Leading dimension for each group's [M,N] slice = stride(1) = N.
+      ldc = out.stride(1);
+      size_t output_group_bytes = M * N * out_element_size;
+      void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
+
+      gemm_descs.push_back({
+        static_cast<ck::index_t>(M),
+        static_cast<ck::index_t>(N),
+        static_cast<ck::index_t>(k),
+        static_cast<ck::index_t>(lda),
+        static_cast<ck::index_t>(ldb),
+        static_cast<ck::index_t>(ldc),
+        {} // --> stride_Ds_
+      });
+      p_a_ptrs.push_back(group_a_ptr);
+      p_b_ptrs.push_back(group_b_ptr);
+      p_e_ptrs.push_back(group_e_ptr);
+    }
+  } else if (mat_a_dim == 2 && mat_b_dim == 3) {
+    // The 2D*3D case requires an offset tensor.
+    auto offs_accessor = offs->accessor<int, 1>();
+    int num_groups = offs_accessor.size(0);
+
+    // For 2D*3D inputs the output is 2D.
+    // A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
+    // The offsets divide the M dimension (rows of A); each group gets different rows of A and a different batch of B.
+    const int K = mat_a.size(1); // columns in A
+    // For the 2D-3D case, the output determines N (the result width).
+    const int N = out.size(1); // N is the width of the output tensor
+
+    for (int i = 0; i < num_groups; ++i) {
+      int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
+      int end_m = offs_accessor[i];
+      int m = end_m - start_m;
+
+      // Skip zero-sized groups but continue processing subsequent groups.
+      if (m <= 0) {
+        continue;
+      }
+
+      // Select the A rows for group i: skip start_m rows.
+      const void* group_a_ptr;
+      int lda;
+      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
+        // Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart.
+        group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
+        lda = mat_a.stride(0); // distance between rows
+      } else {
+        // Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows).
+        group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
+
+        // Detect the stride pattern of A to determine the appropriate lda calculation.
+        bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
+
+        if (a_is_strided_tensor) {
+          // For strided A tensors: stride(0) gives the actual leading dimension.
+          lda = mat_a.stride(0);
+        } else {
+          // For non-strided A tensors: use the M dimension (total rows).
+          lda = mat_a.size(0); // total M dimension for the column-major layout
+        }
+      }
+
+      // Select the B batch for group i: B[i, :, :]
+      const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
+      int ldb;
+
+      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
+        // Row-major GEMM: B is expected as [K, N] but we have [N, K], so a transpose is needed.
+        ldb = mat_b.stride(2); // leading dimension for accessing B as [K, N]
+      } else {
+        // Detect the stride pattern to determine the appropriate ldb calculation.
+        bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
+
+        if (is_strided_tensor) {
+          // For strided tensors: stride(2) gives the actual leading dimension.
+          ldb = mat_b.stride(2);
+        } else {
+          // For non-strided tensors: use the N dimension.
+          ldb = mat_b.size(1);
+        }
+      }
+
+      // Output for this group: rows [start_m:end_m, :] in the 2D output [total_m, N].
+      void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
+      int ldc = out.stride(0); // distance between rows in the output (should be N for the 2D case)
+
+      gemm_descs.push_back({
+        static_cast<ck::index_t>(m),
+        static_cast<ck::index_t>(N),
+        static_cast<ck::index_t>(K),
+        static_cast<ck::index_t>(lda),
+        static_cast<ck::index_t>(ldb),
+        static_cast<ck::index_t>(ldc),
+        {} // --> stride_Ds_
+      });
+      p_a_ptrs.push_back(group_a_ptr);
+      p_b_ptrs.push_back(group_b_ptr);
+      p_e_ptrs.push_back(group_e_ptr);
+    }
+  } else if (mat_a_dim == 3 && mat_b_dim == 3) {
+    // 3D*3D input, 3D output - batched matrix multiplication.
+    // A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
+    // Each batch is processed as a separate GEMM operation.
+    const int batch_size = mat_a.size(0);
+    const int M = mat_a.size(1); // rows in each A matrix
+    const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
+
+    // Determine N from the B tensor - it could be B.size(1) or B.size(2) depending on the layout.
+    int N;
+    if (mat_b.size(1) == K) {
+      // B is [batch, k, n] - normal layout
+      N = mat_b.size(2);
+    } else if (mat_b.size(2) == K) {
+      // B is [batch, n, k] - transposed layout
+      N = mat_b.size(1);
+    } else {
+      TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
+                  batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
+    }
+
+    for (int i = 0; i < batch_size; ++i) {
+      // Select the A batch for group i: A[i, :, :]
+      const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
+
+      // Select the B batch for group i: B[i, :, :]
+      const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
+
+      // Select the output batch for group i: Output[i, :, :]
+      void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
+
+      int lda, ldb, ldc;
+
+      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
+        // Row-major A: leading dimension = distance between rows = stride(1).
+        lda = mat_a.stride(1);
+      } else {
+        // Column-major A: leading dimension = distance between columns = stride(2).
+        lda = mat_a.stride(2);
+      }
+
+      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
+        // Row-major B: leading dimension = distance between rows.
+        if (mat_b.size(1) == K) {
+          // B is [batch, k, n] - normal layout
+          ldb = mat_b.stride(1); // stride between K rows
+        } else {
+          // B is [batch, n, k] - transposed layout, treated as [k, n] for the GEMM
+          ldb = mat_b.stride(2); // stride between N rows (since we access B as [k,n])
+        }
+      } else {
+        // Column-major B: leading dimension = distance between columns.
+        if (mat_b.size(1) == K) {
+          // B is [batch, k, n] - normal layout
+          ldb = mat_b.stride(2); // stride between N columns
+        } else {
+          // B is [batch, n, k] - transposed layout
+          ldb = mat_b.stride(1); // stride between K columns (since we access B as [n,k] -> [k,n])
+        }
+      }
+
+      // The output is row-major: leading dimension = distance between rows = stride(1).
+      ldc = out.stride(1);
+
+      gemm_descs.push_back({
+        static_cast<ck::index_t>(M),
+        static_cast<ck::index_t>(N),
+        static_cast<ck::index_t>(K),
+        static_cast<ck::index_t>(lda),
+        static_cast<ck::index_t>(ldb),
+        static_cast<ck::index_t>(ldc),
+        {} // --> stride_Ds_
+      });
+      p_a_ptrs.push_back(group_a_ptr);
+      p_b_ptrs.push_back(group_b_ptr);
+      p_e_ptrs.push_back(group_e_ptr);
+    }
+  } else if (mat_a_dim == 3 && mat_b_dim == 2) {
+    // The 3D*2D case requires an offset tensor.
+    auto offs_accessor = offs->accessor<int, 1>();
+    int num_groups = offs_accessor.size(0);
+    // For 3D*2D inputs the output is 3D.
+    // A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
+    // The offsets divide the N dimension of B; each group gets a different slice of B and a different batch of A.
+    const int batch_size = mat_a.size(0); // n_groups
+    const int M = mat_a.size(1); // rows in each A matrix
+    const int K = mat_a.size(2); // columns in A
+
+    // For the row-major A and B case: B should be [K, total_N].
+    const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
+
+    for (int i = 0; i < num_groups; ++i) {
+      int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
+      int end_n = offs_accessor[i];
+      int n = end_n - start_n;
+
+      // Skip zero-sized groups but continue processing subsequent groups.
+      if (n <= 0) {
+        continue;
+      }
+
+      // Select the A batch for group i: A[i, :, :]
+      const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
+
+      // Select the B slice for group i: B[:, start_n:end_n] (B is [K, total_N])
+      const void* group_b_ptr;
+      int ldb;
+
+      // Check whether B is row-major or column-major.
+      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
+        // Row-major B [K, total_N]: slice columns [start_n:end_n].
+        group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
+        ldb = mat_b.stride(0); // distance between rows (should be total_N)
+      } else {
+        // Column-major B [K, total_N]: slice columns [start_n:end_n].
+        group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
+        ldb = mat_b.stride(1); // distance between columns (should be K)
+      }
+
+      // Select the output slice for group i: Output[:, start_n:end_n]
+      void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
+
+      int lda, ldc;
+
+      // Row-major A: leading dimension = distance between rows = stride(1).
+      lda = mat_a.stride(1);
+      // The output is row-major: leading dimension = distance between rows = stride(0).
+      ldc = out.stride(0);
+
+      gemm_descs.push_back({
+        static_cast<ck::index_t>(M),
+        static_cast<ck::index_t>(n),
+        static_cast<ck::index_t>(K),
+        static_cast<ck::index_t>(lda),
+        static_cast<ck::index_t>(ldb),
+        static_cast<ck::index_t>(ldc),
+        {} // --> stride_Ds_
+      });
+      p_a_ptrs.push_back(group_a_ptr);
+      p_b_ptrs.push_back(group_b_ptr);
+      p_e_ptrs.push_back(group_e_ptr);
+    }
+  } else {
+    TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
+  }
+
+  TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
+
+  // Initialize d_ptrs with the correct size (no D tensors are used, hence the zero-length arrays).
+  std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
+
+  static DeviceOp gemm_instance;
+  auto argument = gemm_instance.MakeArgument(
+    p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
+    gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
+  );
+  TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
+    "CK Group GEMM: argument unsupported (shape/strides/type config)");
+  size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
+  size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
+
+  void* gemm_arg_buf = nullptr;
+  void* ws_buf = nullptr;
+
+  hipMalloc(&gemm_arg_buf, arg_buf_size);
+  hipMalloc(&ws_buf, ws_size);
+
+  gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
+  gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
+
+  auto invoker = gemm_instance.MakeInvoker();
+  hipStream_t stream = c10::hip::getCurrentHIPStream();
+  invoker.Run(argument, {stream});
+  hipFree(gemm_arg_buf);
+  hipFree(ws_buf);
+}
+
+void group_gemm_ck(
+    const at::Tensor& input_a,
+    const at::Tensor& input_b_colmajor,
+    const std::optional<at::Tensor>& offs,
+    const std::optional<at::Tensor>& /*bias*/,
+    at::Tensor& out)
+{
+  // Detect whether input_a is row-major based on its stride pattern.
+  bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
+  bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
+  // Ensure tensor A is row-major and contiguous if it is not already.
+  at::Tensor mat_a = input_a;
+  if (!a_row_major) {
+    // If A is not row-major, make it contiguous (row-major).
+    mat_a = input_a.contiguous();
+  }
+  // Force tensor B to be column-major using the double-transpose trick.
+  // This guarantees stride(0) == 1 and stride(1) == K for a [K, N] shape.
+  at::Tensor mat_b = input_b_colmajor;
+  if (!b_col_major) {
+    mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
+  }
+
+  // For 3D tensors, check the last dimension's stride for row-major detection.
+  a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
+  bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
+
+  if (mat_a.dtype() == at::kBFloat16) {
+    // bf16 path
+    if (a_row_major && b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
+    } else if (a_row_major && !b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
+    } else if (!a_row_major && b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
+    } else {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
+    }
+  } else if (mat_a.dtype() == at::kHalf) {
+    // fp16 path
+    if (a_row_major && b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
+    } else if (a_row_major && !b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
+    } else if (!a_row_major && b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
+    } else {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
+    }
+  } else if (mat_a.dtype() == at::kFloat) {
+    // fp32 path
+    if (a_row_major && b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
+    } else if (a_row_major && !b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
+    } else if (!a_row_major && b_row_major) {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
+    } else {
+      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
+    }
+  } else {
+    TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
+  }
+}
+
+} // namespace detail
+} // namespace hip
+} // namespace at
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 175e6a9649cd2..fe8ae6f77d475 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -316,7 +316,6 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist)
         self.assertEqual(agrad, a.grad)
         self.assertEqual(bgrad, b.grad)
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM120OrLater
     @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
     @parametrize("strided", [False, True])
@@ -355,7 +354,6 @@ def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype):
             start = offs_cpu[i]
         self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out)
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM120OrLater
     @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
     @parametrize("strided", [False, True])
@@ -412,7 +410,6 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype):
 
         self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist)
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM120OrLater
     @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
     @parametrize("strided", [False, True])
@@ -447,7 +444,6 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype):
         out.backward(gO)
         self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out)
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM120OrLater
     @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
     @parametrize("strided", [False, True])
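
For reference, a minimal sketch of the path these test changes now exercise on ROCm, assuming a CK-enabled build (USE_ROCM_CK_GEMM) on a gfx942/gfx950 device and the private torch._grouped_mm entry point used by the tests above; the shapes, tolerances, and variable names are illustrative only and not part of this PR:

# Hypothetical smoke test for the ROCm CK grouped-GEMM path (2D x 3D case).
import torch

device = "cuda"  # ROCm builds expose HIP devices through the CUDA device API
G, M, N, K = 4, 64, 128, 256

a = torch.randn(G * M, K, device=device, dtype=torch.bfloat16)
b = torch.randn(G, K, N, device=device, dtype=torch.bfloat16)
# offs holds the cumulative row boundaries of A assigned to each group: [M, 2M, ..., G*M].
offs = torch.arange(M, G * M + 1, M, device=device, dtype=torch.int32)

out = torch._grouped_mm(a, b, offs=offs, out_dtype=torch.bfloat16)

# Reference: multiply each row block of A with its matching batch of B.
ref = torch.cat([a[i * M:(i + 1) * M] @ b[i] for i in range(G)])
torch.testing.assert_close(out.float(), ref.float(), rtol=1e-2, atol=1e-2)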