Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions aten/src/ATen/native/cuda/Blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>

#ifdef USE_FBGEMM_GENAI
Expand Down Expand Up @@ -1083,16 +1086,6 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals
#endif
}

static bool _grouped_mm_allowed_device() {
#ifdef USE_ROCM
return false;
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
// CUDA capability 8.0 and greater
return dprops->major >= 8;
#endif
}

#ifdef USE_ROCM
static bool _scaled_mm_is_fnuz() {
return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
Expand Down Expand Up @@ -1789,26 +1782,42 @@ Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
bool a_b_and_out_are_bf16 = (
mat_a.dtype() == at::kBFloat16 &&
mat_b.dtype() == at::kBFloat16 &&
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
);
#ifndef USE_ROCM
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
#else
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
// On non CK system(w/ ROCm), make sure use_fast_path is false
#if defined(USE_ROCM_CK_GEMM)
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif //USE_ROCM_CK_GEMM
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif //USE_ROCM_CK_GEMM
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}
return out;
#else
TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
#endif
}

Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
Expand Down
19 changes: 19 additions & 0 deletions aten/src/ATen/native/hip/ck_group_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>

namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);

} // namespace detail
} // namespace hip
} // namespace at
Loading