aten/src/ATen/native/cuda/Blas.cpp (59 additions, 0 deletions)
@@ -18,6 +18,18 @@
#include <c10/util/MaybeOwned.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>

#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -946,9 +958,12 @@ static bool _scaled_mm_allowed_device() {
return true;
}
}
return false;
#else
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
#endif
}
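
The grouped-GEMM code below calls this gate as _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true), so the function evidently takes two defaulted boolean flags. A minimal sketch of the CUDA branch under that assumption (flag semantics are inferred from the names and the visible return expression, not taken verbatim from this diff):

#include <ATen/cuda/CUDAContext.h>

// Sketch only: gate on device architecture. With no flags set it admits
// Hopper+ (sm90+) and Ada (sm89); each flag narrows the check to exactly
// that architecture generation.
static bool scaled_mm_allowed_device_sketch(bool sm90_only = false,
                                            bool sm100_only = false) {
  auto* dprops = at::cuda::getCurrentDeviceProperties();
  if (sm90_only || sm100_only) {
    return (sm90_only && dprops->major == 9) ||
           (sm100_only && dprops->major == 10);
  }
  return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
}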

@@ -1557,9 +1572,53 @@ bool use_fast_accum) {



Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
                        const std::optional<at::Tensor>& offs,
                        const std::optional<at::Tensor>& bias,
                        std::optional<c10::ScalarType> out_dtype) {
  _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
  bool a_b_and_out_are_bf16 = (
      mat_a.dtype() == at::kBFloat16 &&
      mat_b.dtype() == at::kBFloat16 &&
      out_dtype.value_or(at::kBFloat16) == at::kBFloat16
  );
#ifndef USE_ROCM
  bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
#else
  // Note: the non-ROCm path reuses _scaled_mm_allowed_device even though no
  // scales are involved here; it is effectively just an architecture check.
  // On ROCm, _grouped_mm_fallback is safe on any GPU since it only calls the
  // regular mm/bmm kernels, so default to the fallback and enable the fast
  // path only on CK-supported architectures. On ROCm builds without CK,
  // use_fast_path stays false.
  bool use_fast_path = false;
#if defined(USE_ROCM_CK_GEMM)
  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
    use_fast_path = true;
  }
#endif // USE_ROCM_CK_GEMM
#endif
  const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
  if (use_fast_path) {
    // Fast path: no device-to-host sync is needed to read the offsets.
#ifndef USE_ROCM
    at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
    TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif // USE_ROCM_CK_GEMM
#endif
  } else {
    _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
  }
  return out;
}
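
Caller-side, this op is reached through the generated at::_grouped_mm binding (a minimal sketch, assuming that binding and the 2D-by-3D grouping convention in which offs holds int32 cumulative row end-offsets into mat_a; shapes are illustrative):

#include <ATen/ATen.h>
#include <vector>

at::Tensor grouped_mm_example() {
  // Two row groups of mat_a, [0, 32) and [32, 96), each multiplied with the
  // matching matrix mat_b[g].
  auto opts  = at::dtype(at::kBFloat16).device(at::kCUDA);
  auto mat_a = at::randn({96, 64}, opts);      // (M_total, K)
  auto mat_b = at::randn({2, 64, 48}, opts);   // (G, K, N)
  auto offs  = at::tensor(std::vector<int32_t>{32, 96},
                          at::dtype(at::kInt)).to(at::kCUDA);
  // bf16 inputs with a bf16 output keep a_b_and_out_are_bf16 true, so the
  // fast path above is eligible on supported architectures.
  return at::_grouped_mm(mat_a, mat_b, offs,
                         /*bias=*/std::nullopt,
                         /*out_dtype=*/at::kBFloat16);  // (M_total, N)
}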

}

aten/src/ATen/native/hip/ck_group_gemm.h (19 additions, 0 deletions)
@@ -0,0 +1,19 @@
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>

namespace at {
namespace hip {
namespace detail {
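// Grouped GEMM via Composable Kernel: multiplies each group of mat_a with
// its matching slice of mat_b and writes into the preallocated out tensor;
// group boundaries are carried in offs (see the call site in Blas.cpp).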
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);

} // namespace detail
} // namespace hip
} // namespace at
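
To pin down the contract this header declares, the following is a naive reference loop with the same signature (illustrative only: the real group_gemm_ck maps these arguments onto a single CK grouped-GEMM launch, and the 2D-by-3D shape convention plus the omitted bias handling are assumptions, not taken from this diff):

#include <ATen/ATen.h>
#include <ATen/native/hip/ck_group_gemm.h>

namespace at::hip::detail {

// Naive stand-in: slice mat_a (M_total, K) into row groups by offs and
// multiply each against the matching mat_b[g] of shape (K, N).
void group_gemm_ck_reference(
    const at::Tensor& mat_a,
    const at::Tensor& mat_b,
    const std::optional<at::Tensor>& offs,
    const std::optional<at::Tensor>& /*bias*/,
    at::Tensor& out) {
  TORCH_CHECK(offs.has_value(), "offs is required for the 2D/3D case");
  // Unlike the CK fast path, this copies the offsets to the host (d2h sync).
  auto offs_cpu = offs->to(at::kCPU);
  int64_t start = 0;
  for (int64_t g = 0; g < offs_cpu.numel(); ++g) {
    int64_t end = offs_cpu[g].item<int64_t>();
    if (end > start) {
      auto out_slice = out.slice(/*dim=*/0, start, end);
      at::mm_out(out_slice, mat_a.slice(/*dim=*/0, start, end), mat_b[g]);
    }
    start = end;
  }
}

} // namespace at::hip::detail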