Skip to content

Commit e4a80a3

Browse files
committed
Refactor: move grouped GEMM to separate file and cleanup API
Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent 1167f75 commit e4a80a3

File tree

5 files changed

+635
-555
lines changed

5 files changed

+635
-555
lines changed

tests/cpp/operator/test_grouped_gemm.cu

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* See LICENSE for license information.
55
************************************************************************/
66

7+
#include <cublasLt.h>
78
#include <cuda_bf16.h>
89
#include <cuda_runtime.h>
910
#include <gtest/gtest.h>
@@ -314,9 +315,12 @@ std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
314315
}
315316

316317
void run_grouped_gemm_case(const TestParams& params) {
317-
if (params.input_case != InputCase::kBF16 &&
318-
getDeviceComputeCapability() < hopperComputeCapability) {
319-
GTEST_SKIP() << "FP8 grouped GEMM requires Hopper or newer.";
318+
#if CUBLAS_VERSION < 130200
319+
GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is "
320+
<< CUBLAS_VERSION << ".";
321+
#else
322+
if (getDeviceComputeCapability() < hopperComputeCapability) {
323+
GTEST_SKIP() << "Grouped GEMM requires Hopper (SM90) or newer.";
320324
}
321325

322326
const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);
@@ -451,7 +455,6 @@ void run_grouped_gemm_case(const TestParams& params) {
451455
grouped_D.get_handle(),
452456
setup_ws.data(),
453457
cublas_ws.data(),
454-
nullptr,
455458
0,
456459
nullptr,
457460
nullptr,
@@ -477,6 +480,7 @@ void run_grouped_gemm_case(const TestParams& params) {
477480
atol,
478481
rtol);
479482
}
483+
#endif // CUBLAS_VERSION >= 130200
480484
}
481485

482486
class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};

0 commit comments

Comments
 (0)