Commit e9d7164

[release/2.9] Port group_gemm commits from upstream PT (#2829)
Ported the commits below from upstream PyTorch and modified the code accordingly:

- [ROCm] Enable grouped gemm fallback
- [ROCm] Enable group gemm through CK
- [ROCm] Disable group gemm CK path when composable kernel (CK) is not enabled

Test command:
PYTORCH_TEST_WITH_ROCM=1 pytest test/test_matmul_cuda.py -v -k "test_grouped_gemm_2d_2d or test_grouped_gemm_2d_3d or test_grouped_gemm_3d_3d or test_grouped_gemm_3d_2d"

Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent 9c8d5d1 commit e9d7164
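
For reference (not part of the ported change itself): the grouped GEMM exercised by the tests above multiplies per-group slices of one operand against per-group matrices of the other. The sketch below shows the 2D x 3D case (cf. test_grouped_gemm_2d_3d) using only ordinary torch.mm calls per group, the same kind of decomposition the ROCm fallback path relies on per the comment added in Blas.cpp below. The shapes, the cumulative-offset convention for offs, and all variable names are illustrative assumptions, not the actual test code or the PyTorch implementation.

import torch

# Illustrative 2D x 3D grouped GEMM: mat_a (M, K) is split into groups of rows
# by cumulative offsets `offs`; each group is multiplied by its own (K, N)
# matrix from mat_b, producing a single (M, N) output.
M, K, N, n_groups = 64, 32, 16, 4
mat_a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
mat_b = torch.randn(n_groups, K, N, device="cuda", dtype=torch.bfloat16)
offs = torch.tensor([16, 32, 48, 64], device="cuda", dtype=torch.int32)

out = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
start = 0
for g in range(n_groups):
    end = int(offs[g])  # group g covers rows [start, end) of mat_a
    out[start:end] = torch.mm(mat_a[start:end], mat_b[g])
    start = end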

File tree

4 files changed: +504 -18 lines changed


aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 23 additions & 14 deletions
@@ -20,6 +20,9 @@
 #include <ATen/native/cuda/RowwiseScaledMM.h>
 #include <ATen/native/cuda/ScaledGroupMM.h>
 #include <ATen/native/cuda/GroupMM.h>
+#ifdef USE_ROCM
+#include <ATen/native/hip/ck_group_gemm.h>
+#endif
 #include <ATen/ceil_div.h>

 #ifdef USE_FBGEMM_GENAI
@@ -1083,16 +1086,6 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals
 #endif
 }

-static bool _grouped_mm_allowed_device() {
-#ifdef USE_ROCM
-  return false;
-#else
-  auto dprops = at::cuda::getCurrentDeviceProperties();
-  // CUDA capability 8.0 and greater
-  return dprops->major >= 8;
-#endif
-}
-
 #ifdef USE_ROCM
 static bool _scaled_mm_is_fnuz() {
   return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
@@ -1789,26 +1782,42 @@ Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
     const std::optional<at::Tensor>& offs,
     const std::optional<at::Tensor>& bias,
     std::optional<c10::ScalarType> out_dtype) {
-#ifndef USE_ROCM
   _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
   bool a_b_and_out_are_bf16 = (
     mat_a.dtype() == at::kBFloat16 &&
     mat_b.dtype() == at::kBFloat16 &&
     out_dtype.value_or(at::kBFloat16) == at::kBFloat16
   );
+#ifndef USE_ROCM
   bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
+#else
+  // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
+  // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
+  bool use_fast_path = false;
+  // On non CK system(w/ ROCm), make sure use_fast_path is false
+#if defined(USE_ROCM_CK_GEMM)
+  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
+    use_fast_path = true;
+  }
+#endif //USE_ROCM_CK_GEMM
+#endif
   const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
   if (use_fast_path) {
     // fast path, no d2h sync needed
+#ifndef USE_ROCM
     at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
+#else
+#if defined(USE_ROCM_CK_GEMM)
+    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
+#else
+    TORCH_WARN("ROCm: Group Gemm through CK not selected.");
+#endif //USE_ROCM_CK_GEMM
+#endif
   } else {
     _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
   }
   return out;
-#else
-  TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
-#endif
 }

 Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
aten/src/ATen/native/hip/ck_group_gemm.h (new file)

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+#include <c10/core/ScalarType.h>
+#include <optional>
+
+namespace at {
+namespace hip {
+namespace detail {
+void group_gemm_ck(
+    const at::Tensor& mat_a,
+    const at::Tensor& mat_b,
+    const std::optional<at::Tensor>& offs,
+    const std::optional<at::Tensor>& bias,
+    at::Tensor& out);
+
+} // namespace detail
+} // namespace hip
+} // namespace at
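
Note: the ROCm fast path added in Blas.cpp above is taken only when the build defines USE_ROCM_CK_GEMM and the device reports a gfx942 or gfx950 architecture; everything else goes through the mm/bmm-based _grouped_mm_fallback. The snippet below is a small sketch for checking the architecture side of that gate from Python; it assumes a ROCm build where device properties expose gcnArchName, and it cannot tell whether CK support was compiled in.

import torch

# Sketch: report whether the current GPU matches the CK-eligible architectures
# gated in _grouped_mm_cuda (gfx942 / gfx950). Assumes a ROCm build exposing
# `gcnArchName` on device properties; does not check the USE_ROCM_CK_GEMM flag.
if torch.cuda.is_available() and torch.version.hip is not None:
    arch = torch.cuda.get_device_properties(0).gcnArchName
    ck_arch_ok = any(arch.startswith(g) for g in ("gfx942", "gfx950"))
    print(f"arch={arch}, CK fast-path architecture eligible: {ck_arch_ok}")
else:
    print("Not a ROCm build/device; the CUDA sm90/sm100 fast-path gating applies instead.")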
