diff --git a/examples/09_bmg_grouped_gemm_f8/09_bmg_grouped_gemm_f8.cpp b/examples/09_bmg_grouped_gemm_f8/09_bmg_grouped_gemm_f8.cpp
index c5b0233ac7..a691a1ed04 100644
--- a/examples/09_bmg_grouped_gemm_f8/09_bmg_grouped_gemm_f8.cpp
+++ b/examples/09_bmg_grouped_gemm_f8/09_bmg_grouped_gemm_f8.cpp
@@ -576,19 +576,22 @@ int launcher(Options& options)
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutD = cutlass::layout::RowMajor;
-  using GmemTiledCopyA = XE_2D_U8x32x32_LD_V;
-  using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;
+  using GmemTiledCopyA = void;
+  using GmemTiledCopyB = void;
   // Workgroup-level tile
   using TileShape = Shape<_256, _256, _32>;
   using TiledMma =
-      typename TiledMMAHelper, Layout,
-          Layout, Stride<_4, _1, _0>>>::TiledMMA;
+      typename TiledMMAHelper<
+          MMA_Atom>,
+          Layout,
+          Layout, Stride<_4, _1, _0>>
+      >::TiledMMA;
   constexpr int PipelineStages = 2;
   // Dispatch to grouped gemm algorithm
-  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupFP8;
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1StagedGroup;
   using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
   using EpilogueOp = cutlass::epilogue::fusion::LinearCombination
+#include
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/device_memory.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/reference/device/gemm_complex.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "sycl_common.hpp"
+#include "helper.h"
+
+#include
+
+using namespace cute;
+using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group
+
+using ElementAccumulator = float;     // <- data type of accumulator
+using ElementComputeEpilogue = float; // <- data type of epilogue operations
+using ElementOutput = float;          // <- data type of elements in output matrix D
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+// Command line options parsing
+struct Options {
+
+  bool error = false;
+  bool help = false;
+
+  float alpha, beta;
+  int iterations;
+  int m, n, k, groups;
+  std::vector problem_sizes_host;
+
+  Options() : error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100),
+              m(5120), n(4096), k(4096), groups(2) {
+    problem_sizes_host.reserve(groups);
+    for(int i = 0; i < groups; i++) {
+      problem_sizes_host.push_back({m, n, k});
+    }
+  }
+
+  // Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+      return;
+    }
+
+    cmd.get_cmd_line_argument("m", m, 5120);
+    cmd.get_cmd_line_argument("n", n, 4096);
+    cmd.get_cmd_line_argument("k", k, 4096);
+    cmd.get_cmd_line_argument("groups", groups, 2);
+    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+    cmd.get_cmd_line_argument("beta", beta, 0.f);
+    cmd.get_cmd_line_argument("iterations", iterations, 100);
+
+    assert(groups > 0);
+    problem_sizes_host.clear();
+    problem_sizes_host.reserve(groups);
+    for(int i = 0; i < groups; i++) {
+      problem_sizes_host.push_back({m, n, k});
+    }
+  }
+
+  /// Prints the usage statement.
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG Grouped GEMM\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "09_bmg_grouped_gemm_fp8" << " --m=5120 --n=4096 --k=4096 --groups=5 --alpha=2.5 --beta=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementAccumulator = ElementOutput; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + // Host-side allocations + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_C_host; + std::vector stride_D_host; + + std::vector alpha_host; + std::vector beta_host; + + // Device-side allocations + cutlass::DeviceAllocation problem_sizes; + + // This example defines all matrices in a single allocation (e.g. block_A), but this is not a + // requirement. Matrix base pointers are read from device allocation (e.g. 
ptr_A) + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation ptr_ref_D; + + cutlass::DeviceAllocation stride_A; + cutlass::DeviceAllocation stride_B; + cutlass::DeviceAllocation stride_C; + cutlass::DeviceAllocation stride_D; + + // Note, this is an array of pointers to alpha and beta scaling values per group + cutlass::DeviceAllocation alpha_device; + cutlass::DeviceAllocation beta_device; + cutlass::DeviceAllocation block_alpha; + cutlass::DeviceAllocation block_beta; + + uint64_t seed = 0; + + // + // Methods + // + template + bool verify(const Options &options) { + bool passed = true; + // Verify against individual reference GEMMs + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + cutlass::DeviceAllocation block_A_fp16(block_A.size()); + cutlass::DeviceAllocation block_B_fp16(block_B.size()); + + // fp8 -> fp16 + convert_dtype( + block_A.get(), + block_A_fp16.get(), + block_A.size() + ); + convert_dtype( + block_B.get(), + block_B_fp16.get(), + block_B.size() + ); + + cutlass::TensorRef ref_A(block_A_fp16.get() + offset_A.at(i), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B_fp16.get() + offset_B.at(i), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), LayoutD::packed({M, N})); + + // + // Compute reference output + // + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha_host.at(i), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta_host.at(i), + ref_C, + ref_D, + ElementAccumulator(0), + 1, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // Wait for kernel to finish + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); + if(!passed) + break; + } + return passed; + } + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + // Compute total allocation sizes across group + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Offset into block allocation of each matrix base pointer + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + 
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +template +void initialize(const Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + // Compute offsets, alpha & beta over group on host + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + // Fill host vector of alpha & beta with random values if using per-group values + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast(rand() % 5 + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + // Fill host ptr vectors with offset addresses into device alpha/beta blocks + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + // Allocate device memory & copy from host + ptr_A.reset(options.groups); + // Per-group alpha and beta + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + // Per-group alpha and beta ptrs + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers + // (alpha_device/beta_device) are passed instead + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); +} + + /// Populates a Gemm::Arguments structure from the given commandline options + typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) + { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both 
alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; + + // Per-GEMM problem shape info may only exist on the device. + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + + return arguments; + } + + template + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) { + allocate(options); + initialize(options); + + Gemm gemm_op; + + auto arguments = args_from_options(options, hw_info, host_problem_shapes_available); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(options); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + compat::wait(); + + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time) / double(options.iterations); + double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); + if constexpr (std::is_same_v) { + std::cout << "Datatype: float_e4m3_t"<< std::endl; + } else if constexpr (std::is_same_v) { + std::cout << "Datatype: float_e5m2_t"<< std::endl; + } else { + static_assert(cutlass::detail::dependent_false, "Not a valid fp8 datatype."); + } + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; + } + + return cutlass::Status::kSuccess; + } + +}; + + +template +int launcher(Options& options) +{ + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U8x32x32_LD_V; + using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + // Dispatch to grouped gemm algorithm + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupFP8; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementType, + cutlass::gemm::TagToStrideA_t, + ElementType, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::GroupScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.template run(options, hw_info)); + + 
  return 0;
+}
+
+int main(int argc, const char** argv)
+{
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, argv);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  if (options.error) {
+    std::cerr << "Aborting execution." << std::endl;
+    return -1;
+  }
+
+  launcher(options);
+  launcher(options);
+  return 0;
+}
diff --git a/examples/09_bmg_grouped_gemm_f8/legacy/CMakeLists.txt b/examples/09_bmg_grouped_gemm_f8/legacy/CMakeLists.txt
new file mode 100644
index 0000000000..90ce00dada
--- /dev/null
+++ b/examples/09_bmg_grouped_gemm_f8/legacy/CMakeLists.txt
@@ -0,0 +1,42 @@
+# Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from
+#    this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+set(TEST_GROUPS_2 --groups=2)
+set(TEST_GROUPS_4 --groups=4)
+
+cutlass_example_add_executable(
+  09_bmg_grouped_gemm_fp8_legacy
+  09_bmg_grouped_gemm_fp8.cpp
+  TEST_COMMAND_OPTIONS
+  TEST_GROUPS_2
+  TEST_GROUPS_4
+)
+if(NOT DPCPP_SYCL_TARGET STREQUAL "spir64")
+  # TODO(codeplay): Remove these once IGC VectorAliasThreshold issue is fixed
+  target_link_options( 09_bmg_grouped_gemm_fp8_legacy PRIVATE -Xs "-options \"-igc_opts 'VectorAliasBBThreshold=10000'\"" )
+endif()
\ No newline at end of file
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 38c6b34b75..9d1533257d 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -111,6 +111,7 @@ if(CUTLASS_ENABLE_SYCL)
     07_bmg_dual_gemm
     08_bmg_gemm_f8
     09_bmg_grouped_gemm_f8
+    09_bmg_grouped_gemm_f8/legacy
     10_bmg_grouped_gemm_mixed_dtype
   )
   add_subdirectory(${EXAMPLE})
diff --git a/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp b/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp
index 36f8c85587..acd40519f0 100644
--- a/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp
+++ b/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp
@@ -48,13 +48,13 @@ using namespace cute;
 template
-struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
+struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
                      GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_,
                      SmemCopyAtomB_, TransformB_> {
   //
   // Type Aliases
   //
-  using DispatchPolicy = MainloopIntelXeXMX16GroupFP8;
+  using DispatchPolicy = MainloopXeL1StagedGroup;
   using WorkgroupTileShape = TileShape_;
   using ElementA = ElementA_;
   using StrideA = StrideA_;
@@ -74,7 +74,7 @@ struct CollectiveMma, TileShape_,
   using TransformB = TransformB_;
   using ArchTag = typename DispatchPolicy::ArchTag;
 
-  static_assert(platform::is_same::value, "MainloopIntelXeXMX16Array requires that A and B have same type.");
+  static_assert(platform::is_same::value, "MainloopXeL1StagedGroup requires that A and B have same type.");
   static_assert(std::is_same_v || std::is_same_v);
   static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC");
   static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC");
diff --git a/include/cutlass/gemm/collective/xe_array_mma_fp8_legacy.hpp b/include/cutlass/gemm/collective/xe_array_mma_fp8_legacy.hpp
new file mode 100644
index 0000000000..36f8c85587
--- /dev/null
+++ b/include/cutlass/gemm/collective/xe_array_mma_fp8_legacy.hpp
@@ -0,0 +1,308 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
+ * Copyright (C) 2025 Intel Corporation, All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3.
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cutlass/fp8_to_fp16.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopIntelXeXMX16GroupFP8; + using WorkgroupTileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(platform::is_same::value, "MainloopIntelXeXMX16Array requires that A and B have same type."); + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); + static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr int BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr int BLK_N = get<1>(WorkgroupTileShape{}); + static constexpr int BLK_K = get<2>(WorkgroupTileShape{}); + + static constexpr int ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr int ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr int ATOM_K = get<3>(typename 
TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr int SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr int SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr int SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape, C, C>; + + static constexpr int Num_SGs = ATOM_N * ATOM_M * ATOM_K; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + + using Copy_A = typename Copy_Traits::template DefaultTiledCopy; + using Copy_B = typename Copy_Traits::template DefaultTiledCopy; + + using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) + using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideB{})); //(n, k) + using MainloopTensors = cute::tuple; + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + struct Params { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + const int32_t mock_L = 1; + auto problem_shape_MNK = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, mock_L);; + auto init_M = get<0>(problem_shape_MNK); + auto init_N = get<1>(problem_shape_MNK); + auto init_K = get<2>(problem_shape_MNK); + + return Params{ + args.ptr_A, + args.dA, + args.ptr_B, + args.dB + }; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + + constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + + if (L > 1) { + implementable &= get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; + implementable &= get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + return implementable; + } + + /// Perform a subgroup-scoped matrix multiply-accumulate + template + CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int const& k_tile_count, + BlkCoord const &blk_coord, int const &K_start, int const& thread_idx, + Params const &mainloop, LoadTensors const& load_tensors) { + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + (void)thread_idx; + + Copy_A 
tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; + Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; + + auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); + + // Instantiate the MMA object and get thread slice + TiledMma tiled_mma; + // TODO(Codeplay): see if we can make this nicer + // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; + auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + + // Partition global counting tensors for MMA + Tensor tCgA = thr_mma.partition_A(gA); + Tensor tCgB = thr_mma.partition_B(gB); + + Tensor tCrA = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape())); + Tensor tCrB = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape())); + + Tensor tCrA_fp16 = make_fragment_like(tCrA); + Tensor tCrB_fp16 = make_fragment_like(tCrB); + + // Retile registers for copies + Tensor tArA = thr_copy_A.retile_D(tCrA); + Tensor tBrB = thr_copy_B.retile_D(tCrB); + + // Retile global counting tensors for copies + Tensor tAgA = thr_copy_A.retile_S(tCgA); + Tensor tBgB = thr_copy_B.retile_S(tCgB); + + auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_a); + auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_b); + auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); + + // Partition global tile for prefetch + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + print(" gA : "); print(gA); print("\n"); + print("tCgA : "); print(tCgA); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + + print("===================== B :\n"); + print(" gB : "); print(gB); print("\n"); + print("tCgB : "); print(tCgB); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + + print("===================== Config: \n"); + print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n"); + print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n"); + } +#endif + + // + // Mainloop + // + const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); + constexpr int barrier_scope = 2; + int prefetch_k = k_start_idx; + + CUTLASS_PRAGMA_UNROLL + for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + // Copy gmem to rmem for the first k_tile + copy(tiled_copy_a, tAgA(_,_,_,k_tile), tArA); + copy(tiled_copy_b, tBgB(_,_,_,k_tile), tBrB); + + convert_FP8_to_FP16(tCrA, tCrA_fp16); + convert_FP8_to_FP16(tCrB, tCrB_fp16); + + if (prefetch_k < k_tile_count) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + cute::gemm(tiled_mma, tCrA_fp16, tCrB_fp16, accum); + barrier_wait(barrier_scope); + } + } + + template + CUTLASS_DEVICE auto update_tensor_shape_stride( + Params const& mainloop_params, + int32_t const& next_group, + ProblemShape_MNKL const& 
problem_shape_mnkl) { + const int32_t M = get<0>(problem_shape_mnkl); + const int32_t N = get<1>(problem_shape_mnkl); + const int32_t K = get<2>(problem_shape_mnkl); + + ElementA const* ptr_A_curr_batch = reinterpret_cast(mainloop_params.ptr_A[next_group]); + ElementB const* ptr_B_curr_batch = reinterpret_cast(mainloop_params.ptr_B[next_group]); + + Tensor mA = make_tensor(make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K,(int32_t)1), mainloop_params.dA[next_group]); + Tensor mB = make_tensor(make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K,(int32_t)1), mainloop_params.dB[next_group]); + + return cute::make_tuple(mA, mB); + } +}; + +} // namespace cutlass::gemm::collective + +/////////////////////////////////////////////////////////////////////////////////////////////////
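
Note on the retained legacy path: xe_array_mma_fp8_legacy.hpp keeps A and B as FP8 in registers after the 2D block loads and widens them with convert_FP8_to_FP16 (from cutlass/fp8_to_fp16.h) right before cute::gemm, so global-memory traffic stays at 8 bits per element while the XMX MMA runs on FP16 operands. The standalone sketch below illustrates that element-wise widening using cutlass::NumericConverter; the helper name widen_fp8_to_fp16 and the host-side setting are illustrative assumptions, not code from this patch.

// Illustrative sketch only (not part of the patch): element-wise e4m3 -> half widening,
// routed through float, mirroring what the legacy mainloop does on its register fragments.
#include <cstddef>
#include <vector>
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

std::vector<cutlass::half_t> widen_fp8_to_fp16(std::vector<cutlass::float_e4m3_t> const& src) {
  // Convert e4m3 -> float, then float -> half; both converters are standard CUTLASS utilities.
  cutlass::NumericConverter<float, cutlass::float_e4m3_t> to_float;
  cutlass::NumericConverter<cutlass::half_t, float> to_half;
  std::vector<cutlass::half_t> dst(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    dst[i] = to_half(to_float(src[i]));  // storage stays 8-bit; math runs in FP16
  }
  return dst;
}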