
Commit 1fa520e

jagadish-amd authored and pytorchmergebot committed
[ROCm] Enable group gemm through CK (pytorch#166334)
Fixes pytorch#161366

All four dimension combinations are supported: 2d-2d, 2d-3d, 3d-2d, and 3d-3d. The corresponding test cases in test_matmul_cuda pass for both the forward and backward pass. The CK path is enabled for gfx942 and gfx950.

ToDo: enable support on gfx90a; the CK kernel used in this commit produces a GPU error there and, based on the profiler result on gfx90a, may require a different CK kernel config.

Pull Request resolved: pytorch#166334
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony
1 parent: c2e3cc7
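For orientation, here is a rough usage sketch of the grouped GEMM entry point that test_matmul_cuda exercises. The torch._grouped_mm call, the bf16 dtype, the transposed (column-major) second operand, and the offset convention are assumptions based on the existing CUDA fast path, not something this commit introduces:

import torch

# Illustrative only: shapes and argument conventions are assumed, not taken from this commit.
G, M, K, N = 4, 64, 128, 256
dev, dt = "cuda", torch.bfloat16

# 3d-3d: one matmul per group, with the group index as the leading batch dimension.
a3 = torch.randn(G, M, K, device=dev, dtype=dt)
b3 = torch.randn(G, N, K, device=dev, dtype=dt)
out_3d3d = torch._grouped_mm(a3, b3.transpose(-2, -1))  # (G, M, N)

# 2d-3d: groups are consecutive row slices of `a`, delimited by cumulative offsets.
a2 = torch.randn(G * M, K, device=dev, dtype=dt)
offs = torch.arange(M, G * M + 1, M, device=dev, dtype=torch.int32)
out_2d3d = torch._grouped_mm(a2, b3.transpose(-2, -1), offs=offs)  # (G * M, N)

On gfx942/gfx950 these calls now take the CK fast path added below; on other ROCm GPUs they still go through _grouped_mm_fallback.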

4 files changed: +487 −2 lines


aten/src/ATen/native/cuda/GroupedBlas.cpp

Lines changed: 10 additions & 0 deletions
@@ -22,6 +22,9 @@
 #include <ATen/native/cuda/RowwiseScaledMM.h>
 #include <ATen/native/cuda/ScaledGroupMM.h>
 #include <ATen/native/cuda/GroupMM.h>
+#ifdef USE_ROCM
+#include <ATen/native/hip/ck_group_gemm.h>
+#endif
 #include <ATen/ceil_div.h>
 
 #ifdef USE_FBGEMM_GENAI
@@ -636,12 +639,19 @@ std::optional<c10::ScalarType> out_dtype) {
   // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
   // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
   bool use_fast_path = false;
+  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
+    use_fast_path = true;
+  }
 #endif
   const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
   if (use_fast_path) {
     // fast path, no d2h sync needed
+#ifndef USE_ROCM
     at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
+#else
+    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
+#endif
   } else {
     _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
   }
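As the comment in the hunk above says, the non-fast path simply decomposes the grouped product into ordinary mm/bmm calls. A rough Python reference for what the 2d-3d variant computes, written here as an illustrative sketch of the semantics rather than a translation of _grouped_mm_fallback:

import torch

def grouped_mm_2d_3d_reference(a, b, offs):
    # a: (total_M, K), b: (G, K, N), offs: cumulative row boundaries into `a`
    # (offs[-1] == total_M). Each group multiplies its row slice of `a` with b[g].
    outs, start = [], 0
    for g, end in enumerate(offs.tolist()):
        outs.append(a[start:end] @ b[g])
        start = end
    return torch.cat(outs, dim=0)  # (total_M, N)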
aten/src/ATen/native/hip/ck_group_gemm.h

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+#include <c10/core/ScalarType.h>
+#include <optional>
+
+namespace at {
+namespace hip {
+namespace detail {
+void group_gemm_ck(
+    const at::Tensor& mat_a,
+    const at::Tensor& mat_b,
+    const std::optional<at::Tensor>& offs,
+    const std::optional<at::Tensor>& bias,
+    at::Tensor& out);
+
+} // namespace detail
+} // namespace hip
+} // namespace at
