diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt
index ad6f93b3e8..6b8146800d 100644
--- a/test/unit/gemm/device/CMakeLists.txt
+++ b/test/unit/gemm/device/CMakeLists.txt
@@ -42,6 +42,7 @@ if(CUTLASS_ENABLE_SYCL)
     xe_gemm_f8_f8_fp32_tensor_op_fp32.cpp
     xe_gemm_fp16_s8_fp32_tensor_op_fp32.cpp
     gemm_universal_bf16n_bf16t_f32n_tensor_op_f32_xe.cpp
+    gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp
   )
 
   cutlass_test_unit_add_executable(
diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp
index fb2e3dc6bd..24714067c4 100644
--- a/test/unit/gemm/device/gemm_testbed_3x.hpp
+++ b/test/unit/gemm/device/gemm_testbed_3x.hpp
@@ -4187,6 +4187,47 @@ bool TestXe(
   } // m
   return passed;
 }
+
+template <typename Gemm, template <class T> class ActivationFunctor =
+                             cutlass::epilogue::thread::Identity>
+bool TestXe(
+    int m, int n, int k, int l,
+    double alpha = 1.0,
+    double beta = cute::is_same_v<typename Gemm::GemmKernel::ElementC, void> ? 0.0 : 1.0,
+    CheckEquality check_relative_equality = CheckEquality::RELATIVE) {
+  using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar;
+  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
+
+  Testbed3x<Gemm, ActivationFunctor> testbed(
+      check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED);
+
+  bool passed = true;
+  ProblemShapeType problem_size{m, n, k, l};
+  try {
+    passed = testbed.run(problem_size,
+                         cutlass::from_real<ElementScalar>(alpha),
+                         cutlass::from_real<ElementScalar>(beta));
+  }
+  catch (std::exception const& e) {
+    EXPECT_TRUE(false) << "TestXe: testbed.run threw an exception: " << e.what();
+    throw;
+  }
+  catch (...) {
+    EXPECT_TRUE(false) << "TestXe: testbed.run threw an unknown exception";
+    throw;
+  }
+
+  EXPECT_TRUE(passed) << "TestXe: testbed.run failed for MNKL = "
+                      << m << " " << n << " " << k << " " << l
+                      << ", alpha: " << alpha << ", beta: " << beta;
+
+  return passed;
+}
+
 #endif
 
 template
diff --git a/test/unit/gemm/device/gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp b/test/unit/gemm/device/gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp
new file mode 100644
index 0000000000..5dd1a308ce
--- /dev/null
+++ b/test/unit/gemm/device/gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp
@@ -0,0 +1,179 @@
+/***************************************************************************************************
+ * Copyright (C) 2025 Intel Corporation, All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide GEMM interface
+
+*/
+#include <iostream>
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/collective/collective_mma.hpp"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "default_gemm_configuration.hpp"
+#include "gemm_testbed_3x.hpp"
+
+using namespace cutlass;
+
+namespace {
+
+template <typename LayoutA, typename LayoutB, typename LayoutC>
+struct MainloopIntelW8A8_GemmConfig {
+  using ElementA = float_e5m2_t;
+  using ElementB = float_e5m2_t;
+  using TileShape = Shape<_256, _256, _32>;
+  constexpr static int PipelineStages = 2;
+  using Schedule = gemm::KernelXe;
+  using TiledMma = typename TiledMMAHelper<
+      MMA_Atom<XE_8x16x16_F32F16F16F32_TT>,
+      Layout<TileShape>,
+      Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>
+  >::TiledMMA;
+  using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
+  using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;
+
+  using DispatchPolicy = gemm::MainloopIntelW8A8<PipelineStages, Schedule>;
+
+  using CollectiveMainloop = gemm::collective::CollectiveMma<
+      DispatchPolicy, TileShape,
+      ElementA, cutlass::gemm::TagToStrideA_t<LayoutA>,
+      ElementB, cutlass::gemm::TagToStrideB_t<LayoutB>,
+      TiledMma,
+      GmemTiledCopyA, void, void, cute::identity, // A
+      GmemTiledCopyB, void, void, cute::identity  // B
+  >;
+
+  using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<
+      float, float
+  >;
+
+  using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
+      cutlass::epilogue::IntelXeXMX16,
+      EpilogueOp,
+      TileShape,
+      decltype(tile_shape(TiledMma()))
+  >;
+
+  using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
+      cutlass::epilogue::IntelXeXMX16,
+      TileShape,
+      float, cutlass::gemm::TagToStrideC_t<LayoutC>,
+      float, cutlass::gemm::TagToStrideC_t<LayoutC>,
+      FusionCallBacks,
+      XE_2D_U32x8x16_LD_N, void, void,
+      XE_2D_U32x8x16_ST_N, void, void
+  >;
+
+  using GemmKernel = gemm::kernel::GemmUniversal<
+      cute::Shape<int, int, int, int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = gemm::device::GemmUniversalAdapter<GemmKernel>;
+};
+
+TEST(MainloopIntelW8A8_Special, LargeModel_LLaMA2_7B) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(4096, 4096, 11008, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, LargeModel_Mistral_7B) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(4096, 4096, 14336, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, TensorParallel) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(4096, 1024, 4096, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, ModelParallel) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1024, 4096, 4096, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, MicroBatch) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(128, 128, 8192, 4, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, LargeBatch) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(512, 512, 2048, 32, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, SquareSmall) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(64, 64, 64, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, SquareMedium) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(512, 512, 512, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, SquareLarge) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(2048, 2048, 2048, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, TallMatrix) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(4096, 512, 4096, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, WideMatrix) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(512, 4096, 4096, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, Batch8) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(512, 512, 512, 8, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, Batch16Large) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1024, 1024, 1024, 16, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, LargeK) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(64, 64, 8192, 1, 1.0, 0.0));
+}
+
+TEST(MainloopIntelW8A8_Special, LargeN) {
+  using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor, layout::RowMajor>::Gemm;
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(64, 8192, 64, 1, 1.0, 0.0));
+}
+
+} // namespace