Skip to content

Commit 20e2308

Browse files
authored
[BENCHMARK][GEMM] Re-enable CUTLASS's edge cases (#4181)
This PR updates the CUTLASS pin to be in sync with [PR 364](intel/sycl-tla#364), which fixes the edge cases encountered with CUTLASS in our integration. Thanks to this update, we can now also re-enable the previously disabled CUTLASS edge cases.

---------

Signed-off-by: Jefferson Le Quellec <[email protected]>
1 parent d10bbc5 commit 20e2308

File tree

4 files changed

+17
-25
lines changed

4 files changed

+17
-25
lines changed

benchmarks/cmake/FindCUTLASSLibrary.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ if (NOT CUTLASSLibrary_FOUND)
99
set(CUTLASSLibrary_SOURCE_DIR
1010
"${CMAKE_CURRENT_BINARY_DIR}/CUTLASSLibrary")
1111
message(STATUS "CUTLASSLibrary is not specified. Will try to download
12-
CUTLASS library from https://github.com/codeplaysoftware/cutlass-fork into
12+
CUTLASS library from https://github.com/codeplaysoftware/cutlass-sycl.git into
1313
${CUTLASSLibrary_SOURCE_DIR}")
1414
file(READ cutlass_kernel/cutlass-library.conf CUTLASSLibrary_TAG)
1515
# Strip the potential trailing newline from tag
1616
string(STRIP "${CUTLASSLibrary_TAG}" CUTLASSLibrary_TAG)
1717
FetchContent_Declare(cutlass-library
18-
GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-fork
18+
GIT_REPOSITORY https://github.com/codeplaysoftware/cutlass-sycl.git
1919
GIT_TAG ${CUTLASSLibrary_TAG}
2020
SOURCE_DIR ${CUTLASSLibrary_SOURCE_DIR}
2121
)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
f36600cfc1c3a60d9b1d6f64f946a1e877ea33bd
1+
bb48e86d2fe7cb09eab2e719e78d5811d3da3131

benchmarks/cutlass_kernel/python_main.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ static auto gemm_run(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
5656

5757
using CollectiveMainloop =
5858
typename cutlass::gemm::collective::CollectiveBuilder<
59-
cutlass::arch::IntelPVC, cutlass::arch::OpClassTensorOp,
60-
ElementInputA, LayoutA, AlignmentA, ElementInputB, LayoutB,
61-
AlignmentB, ElementAccumulator, TileShape,
59+
cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, ElementInputA,
60+
LayoutA, AlignmentA, ElementInputB, LayoutB, AlignmentB,
61+
ElementAccumulator, TileShape,
6262
cute::Shape<cute::_1, cute::_1, cute::_1>,
6363
cutlass::gemm::collective::StageCountAuto,
6464
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
@@ -71,7 +71,7 @@ static auto gemm_run(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
7171
cutlass::FloatRoundStyle::round_to_nearest>;
7272
using CollectiveEpilogue =
7373
typename cutlass::epilogue::collective::CollectiveBuilder<
74-
cutlass::arch::IntelPVC, cutlass::arch::OpClassTensorOp, TileShape,
74+
cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape,
7575
cute::Shape<cute::_1, cute::_1, cute::_1>,
7676
cutlass::epilogue::collective::EpilogueTileAuto,
7777
ElementComputeEpilogue, ElementAccumulator, ElementAccumulator,
@@ -184,9 +184,8 @@ auto gemm(const at::Tensor &A, const at::Tensor &B, at::Tensor &C, const int M,
184184
return gemm_run<TileShape_RRR_3>(A, B, C, M, N, K, L);
185185
if (test_case == Dim{4096, 8, 128, 16384})
186186
return gemm_run<TileShape_RRR_4>(A, B, C, M, N, K, L);
187-
/// FIXME: Getting a compile time error for RRR_5
188-
// if (test_case == Dimension{4096, 8, 16384, 128})
189-
// return gemm_run<TileShape_RRR_5>(A, B, C, M, N, K, L);
187+
if (test_case == Dim{4096, 8, 16384, 128})
188+
return gemm_run<TileShape_RRR_5>(A, B, C, M, N, K, L);
190189

191190
return gemm_run<TileShape_RRR_1>(A, B, C, M, N, K, L);
192191
}

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -433,21 +433,14 @@ def cutlass_invoker():
433433
cutlass_fn = cutlass_invoker
434434
torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
435435

436-
# FIXME: Remove temporary condition when https://github.com/codeplaysoftware/cutlass-fork/pull/313 will be merged
437-
if (B, M, N, K) == (4096, 8, 128, 16384) or (B, M, N, K) == (4096, 8, 16384, 128):
438-
min_ms = float('nan')
439-
max_ms = float('nan')
440-
mean_ms = float('nan')
441-
cv = float('nan')
442-
else:
443-
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
444-
benchmark_suite.assert_close(cutlass_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='cutlass to torch')
445-
_, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(
446-
cutlass_fn,
447-
n_warmup=10,
448-
n_repeat=10,
449-
quantiles=quantiles,
450-
)
436+
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
437+
benchmark_suite.assert_close(cutlass_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='cutlass to torch')
438+
_, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(
439+
cutlass_fn,
440+
n_warmup=10,
441+
n_repeat=10,
442+
quantiles=quantiles,
443+
)
451444

452445
else:
453446
raise NotImplementedError(f'Unsupported provider {provider}')

0 commit comments

Comments
 (0)